Unverified commit dcd331ab, authored by Ayush Agarwal, committed by GitHub
Browse files

chore: 400 Error Code for Bad Completion and ChatCompletion Request (#3038)


Signed-off-by: ayushag <ayushag@nvidia.com>
parent b1186aee
......@@ -31,6 +31,7 @@ use super::{
metrics::{Endpoint, ResponseMetricCollector},
service_v2,
};
use crate::engines::ValidateRequest;
use crate::protocols::openai::chat_completions::aggregator::ChatCompletionAggregator;
use crate::protocols::openai::{
ParsingOptions,
......@@ -254,6 +255,8 @@ async fn completions(
// return a 503 if the service is not ready
check_ready(&state)?;
validate_completion_fields_generic(&request)?;
let request_id = request.id().to_string();
// todo - decide on default
......@@ -468,6 +471,9 @@ async fn chat_completions(
// Handle required fields like messages shouldn't be empty.
validate_chat_completion_required_fields(&request)?;
// Handle Rest of Validation Errors
validate_chat_completion_fields_generic(&request)?;
// Apply template values if present
if let Some(template) = template {
if request.inner.model.is_empty() {
......@@ -630,6 +636,36 @@ pub fn validate_chat_completion_required_fields(
Ok(())
}
/// Runs the generic field validation for a chat completion request.
///
/// Delegates to the `validate` method of `NvCreateChatCompletionRequest`.
///
/// # Errors
///
/// When `validate` fails, the error's string form is placed in an `HttpError` with
/// code 400 and converted into the returned `ErrorResponse` through
/// `ErrorMessage::from_http_error`.
pub fn validate_chat_completion_fields_generic(
    request: &NvCreateChatCompletionRequest,
) -> Result<(), ErrorResponse> {
    match request.validate() {
        Ok(()) => Ok(()),
        Err(cause) => {
            // Every validation failure is reported to the client as a 400.
            let http_error = HttpError {
                code: 400,
                message: cause.to_string(),
            };
            Err(ErrorMessage::from_http_error(http_error))
        }
    }
}
/// Runs the generic field validation for a completion request.
///
/// Delegates to the `validate` method of `NvCreateCompletionRequest`.
///
/// # Errors
///
/// When `validate` fails, the error's string form is placed in an `HttpError` with
/// code 400 and converted into the returned `ErrorResponse` through
/// `ErrorMessage::from_http_error`.
pub fn validate_completion_fields_generic(
    request: &NvCreateCompletionRequest,
) -> Result<(), ErrorResponse> {
    match request.validate() {
        Ok(()) => Ok(()),
        Err(cause) => {
            // Every validation failure is reported to the client as a 400.
            let http_error = HttpError {
                code: 400,
                message: cause.to_string(),
            };
            Err(ErrorMessage::from_http_error(http_error))
        }
    }
}
/// OpenAI Responses Request Handler
///
/// This method will handle the incoming request for the /v1/responses endpoint.
......@@ -1096,6 +1132,12 @@ pub fn responses_router(
mod tests {
use std::collections::HashMap;
use super::*;
use crate::discovery::ModelManagerError;
use crate::protocols::openai::chat_completions::NvCreateChatCompletionRequest;
use crate::protocols::openai::common_ext::CommonExt;
use crate::protocols::openai::completions::NvCreateCompletionRequest;
use crate::protocols::openai::responses::NvCreateResponse;
use dynamo_async_openai::types::responses::{
CreateResponse, Input, InputContent, InputItem, InputMessage, PromptConfig,
Role as ResponseRole, ServiceTier, TextConfig, TextResponseFormat, ToolChoice,
......@@ -1104,12 +1146,9 @@ mod tests {
use dynamo_async_openai::types::{
ChatCompletionRequestMessage, ChatCompletionRequestUserMessage,
ChatCompletionRequestUserMessageContent, CreateChatCompletionRequest,
CreateCompletionRequest,
};
use super::*;
use crate::discovery::ModelManagerError;
use crate::protocols::openai::responses::NvCreateResponse;
const BACKUP_ERROR_MESSAGE: &str = "Failed to generate completions";
fn http_error_from_engine(code: u16) -> Result<(), anyhow::Error> {
......@@ -1349,4 +1388,316 @@ mod tests {
let result = validate_chat_completion_required_fields(&request);
assert!(result.is_ok());
}
// Bad-request coverage checklist for Completion / ChatCompletion validation:
//   1. echo: should be a boolean - not done
//   2. frequency_penalty: should be a float between -2.0 and 2.0 - done
//   3. logprobs - done
//   4. model format: should be a string - not done
//   5. prompt or messages validation
//   6. max_tokens: should be a positive integer
//   7. presence_penalty: should be a float between -2.0 and 2.0 - done
//   8. stop: should be a string or an array of strings - not done
//   9. invalid or out-of-range temperature - done
//  10. invalid or out-of-range top_p - done
//  11. repetition_penalty: should be a float between 0.0 and 2.0 - done
//  12. logprobs: should be a positive integer between 0 and 5 - done
// Still open:
//   - invalid or non-existent user: only the empty string is rejected today; how could a
//     non-existent user be detected?
//   - add_special_tokens null or invalid - not done
//   - guided_whitespace_pattern null or invalid - not done
//   - "response_format": { "type": "invalid_format" } - not done
//   - "logit_bias": { "invalid_token": "not_a_number" } - partial validation already exists
#[test]
fn test_bad_base_request_for_completion() {
    // Valid baseline request body; every case below overrides exactly one field.
    let baseline = || CreateCompletionRequest {
        model: "test-model".to_string(),
        prompt: "Hello".into(),
        ..Default::default()
    };
    // Wraps a request body with default common extensions and no nvext.
    let wrap = |inner: CreateCompletionRequest| NvCreateCompletionRequest {
        inner,
        common: Default::default(),
        nvext: None,
    };

    let cases = [
        // Frequency Penalty: Should be a float between -2.0 and 2.0
        (
            wrap(CreateCompletionRequest {
                frequency_penalty: Some(-3.0),
                ..baseline()
            }),
            "Frequency penalty must be between -2 and 2, got -3",
        ),
        // Presence Penalty: Should be a float between -2.0 and 2.0
        (
            wrap(CreateCompletionRequest {
                presence_penalty: Some(-3.0),
                ..baseline()
            }),
            "Presence penalty must be between -2 and 2, got -3",
        ),
        // Temperature: Should be a float between 0.0 and 2.0
        (
            wrap(CreateCompletionRequest {
                temperature: Some(-3.0),
                ..baseline()
            }),
            "Temperature must be between 0 and 2, got -3",
        ),
        // Top P: Should be a float between 0.0 and 1.0
        (
            wrap(CreateCompletionRequest {
                top_p: Some(-3.0),
                ..baseline()
            }),
            "Top_p must be between 0 and 1, got -3",
        ),
        // Repetition Penalty: Should be a float between 0.0 and 2.0
        (
            NvCreateCompletionRequest {
                inner: baseline(),
                common: CommonExt::builder()
                    .repetition_penalty(-3.0)
                    .build()
                    .unwrap(),
                nvext: None,
            },
            "Repetition penalty must be between 0 and 2, got -3",
        ),
        // Logprobs: Should be a positive integer between 0 and 5
        (
            wrap(CreateCompletionRequest {
                logprobs: Some(6),
                ..baseline()
            }),
            "Logprobs must be between 0 and 5, got 6",
        ),
    ];

    for (request, expected_message) in cases {
        match validate_completion_fields_generic(&request) {
            Ok(()) => panic!("expected validation to fail with: {expected_message}"),
            Err((status, error_response)) => {
                assert_eq!(status, StatusCode::BAD_REQUEST);
                assert_eq!(error_response.error, expected_message);
            }
        }
    }
}
/// Each out-of-range sampling field on a chat completion request must be rejected with
/// `400 Bad Request` and the matching validation message.
#[test]
fn test_bad_base_request_for_chatcompletion() {
    // Valid baseline request body with a single user message; every case below overrides
    // exactly one field.
    let baseline = || CreateChatCompletionRequest {
        model: "test-model".to_string(),
        messages: vec![ChatCompletionRequestMessage::User(
            ChatCompletionRequestUserMessage {
                content: ChatCompletionRequestUserMessageContent::Text("Hello".to_string()),
                name: None,
            },
        )],
        ..Default::default()
    };
    // Wraps a request body with default common extensions and no nvext.
    let wrap = |inner: CreateChatCompletionRequest| NvCreateChatCompletionRequest {
        inner,
        common: Default::default(),
        nvext: None,
    };

    let cases = [
        // Frequency Penalty: Should be a float between -2.0 and 2.0
        (
            wrap(CreateChatCompletionRequest {
                frequency_penalty: Some(-3.0),
                ..baseline()
            }),
            "Frequency penalty must be between -2 and 2, got -3",
        ),
        // Presence Penalty: Should be a float between -2.0 and 2.0
        (
            wrap(CreateChatCompletionRequest {
                presence_penalty: Some(-3.0),
                ..baseline()
            }),
            "Presence penalty must be between -2 and 2, got -3",
        ),
        // Temperature: Should be a float between 0.0 and 2.0
        (
            wrap(CreateChatCompletionRequest {
                temperature: Some(-3.0),
                ..baseline()
            }),
            "Temperature must be between 0 and 2, got -3",
        ),
        // Top P: Should be a float between 0.0 and 1.0
        (
            wrap(CreateChatCompletionRequest {
                top_p: Some(-3.0),
                ..baseline()
            }),
            "Top_p must be between 0 and 1, got -3",
        ),
        // Repetition Penalty: Should be a float between 0.0 and 2.0
        (
            NvCreateChatCompletionRequest {
                inner: baseline(),
                common: CommonExt::builder()
                    .repetition_penalty(-3.0)
                    .build()
                    .unwrap(),
                nvext: None,
            },
            "Repetition penalty must be between 0 and 2, got -3",
        ),
        // Top Logprobs: Should be a positive integer between 0 and 20
        (
            wrap(CreateChatCompletionRequest {
                top_logprobs: Some(25),
                ..baseline()
            }),
            "Top_logprobs must be between 0 and 20, got 25",
        ),
    ];

    for (request, expected_message) in cases {
        match validate_chat_completion_fields_generic(&request) {
            Ok(()) => panic!("expected validation to fail with: {expected_message}"),
            Err((status, error_response)) => {
                assert_eq!(status, StatusCode::BAD_REQUEST);
                assert_eq!(error_response.error, expected_message);
            }
        }
    }
}
}
......@@ -332,6 +332,8 @@ impl ValidateRequest for NvCreateChatCompletionRequest {
validate::validate_user(self.inner.user.as_deref())?;
// none for function call
// none for functions
// Common Ext
validate::validate_repetition_penalty(self.get_repetition_penalty())?;
Ok(())
}
......
......@@ -425,6 +425,9 @@ impl ValidateRequest for NvCreateCompletionRequest {
validate::validate_user(self.inner.user.as_deref())?;
// none for seed
// Common Ext
validate::validate_repetition_penalty(self.get_repetition_penalty())?;
Ok(())
}
}
......@@ -86,6 +86,10 @@ pub const MAX_METADATA_VALUE_LENGTH: usize = 512;
pub const MAX_FUNCTION_NAME_LENGTH: usize = 64;
/// Maximum allowed value for Prompt IntegerArray elements
pub const MAX_PROMPT_TOKEN_ID: u32 = 50256;
/// Minimum allowed value for `repetition_penalty`.
///
/// Inclusive lower bound of the range checked by `validate_repetition_penalty`.
pub const MIN_REPETITION_PENALTY: f32 = 0.0;
/// Maximum allowed value for `repetition_penalty`.
///
/// Inclusive upper bound of the range checked by `validate_repetition_penalty`.
pub const MAX_REPETITION_PENALTY: f32 = 2.0;
//
// Shared Fields
......@@ -164,6 +168,20 @@ pub fn validate_presence_penalty(presence_penalty: Option<f32>) -> Result<(), an
Ok(())
}
/// Checks that an optional `repetition_penalty` lies within the permitted range.
///
/// `None` is always accepted. A supplied value must fall inside the inclusive range
/// `MIN_REPETITION_PENALTY..=MAX_REPETITION_PENALTY`.
///
/// # Errors
///
/// Returns an error naming both bounds and the offending value when the supplied
/// penalty is outside that range.
pub fn validate_repetition_penalty(repetition_penalty: Option<f32>) -> Result<(), anyhow::Error> {
    // Nothing to validate when the caller did not set the field.
    let Some(value) = repetition_penalty else {
        return Ok(());
    };

    let allowed = MIN_REPETITION_PENALTY..=MAX_REPETITION_PENALTY;
    if allowed.contains(&value) {
        return Ok(());
    }

    anyhow::bail!(
        "Repetition penalty must be between {} and {}, got {}",
        MIN_REPETITION_PENALTY,
        MAX_REPETITION_PENALTY,
        value
    )
}
/// Validates logit bias map
pub fn validate_logit_bias(
logit_bias: &Option<std::collections::HashMap<String, serde_json::Value>>,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment