Unverified Commit c837b5ba authored by KrishnanPrash's avatar KrishnanPrash Committed by GitHub
Browse files

feat: Reject unsupported parameters with 400 Bad Request (#4021)


Signed-off-by: default avatarKrishnan Prashanth <kprashanth@nvidia.com>
parent f31a1dad
...@@ -229,6 +229,7 @@ async fn evaluate( ...@@ -229,6 +229,7 @@ async fn evaluate(
common: Default::default(), common: Default::default(),
nvext: None, nvext: None,
chat_template_args: None, chat_template_args: None,
unsupported_fields: Default::default(),
}; };
let mut stream = engine.generate(Context::new(req)).await?; let mut stream = engine.generate(Context::new(req)).await?;
let mut output = String::new(); let mut output = String::new();
......
...@@ -112,6 +112,7 @@ async fn main_loop( ...@@ -112,6 +112,7 @@ async fn main_loop(
common: Default::default(), common: Default::default(),
nvext: None, nvext: None,
chat_template_args: None, chat_template_args: None,
unsupported_fields: Default::default(),
}; };
// Call the model // Call the model
......
...@@ -1403,6 +1403,7 @@ mod tests { ...@@ -1403,6 +1403,7 @@ mod tests {
common: Default::default(), common: Default::default(),
nvext: None, nvext: None,
chat_template_args: None, chat_template_args: None,
unsupported_fields: Default::default(),
}; };
let result = validate_chat_completion_required_fields(&request); let result = validate_chat_completion_required_fields(&request);
assert!(result.is_err()); assert!(result.is_err());
...@@ -1431,6 +1432,7 @@ mod tests { ...@@ -1431,6 +1432,7 @@ mod tests {
common: Default::default(), common: Default::default(),
nvext: None, nvext: None,
chat_template_args: None, chat_template_args: None,
unsupported_fields: Default::default(),
}; };
let result = validate_chat_completion_required_fields(&request); let result = validate_chat_completion_required_fields(&request);
assert!(result.is_ok()); assert!(result.is_ok());
...@@ -1451,7 +1453,7 @@ mod tests { ...@@ -1451,7 +1453,7 @@ mod tests {
// 11. Repetition Penalty: Should be a float between 0.0 and 2.0 : Done // 11. Repetition Penalty: Should be a float between 0.0 and 2.0 : Done
// 12. Logprobs: Should be a positive integer between 0 and 5 : Done // 12. Logprobs: Should be a positive integer between 0 and 5 : Done
// invalid or non existing user : Only empty string is not allowed validation is there. How can we check non-extisting user ? // invalid or non existing user : Only empty string is not allowed validation is there. How can we check non-extisting user ?
// add_special_tokens null or invalid : Not Done // Unknown fields : Done (rejected via extra_fields catch-all)
// guided_whitespace_pattern null or invalid : Not Done // guided_whitespace_pattern null or invalid : Not Done
// "response_format": { "type": "invalid_format" } : Not Done // "response_format": { "type": "invalid_format" } : Not Done
// "logit_bias": { "invalid_token": "not_a_number" }, : Partial Validation is already there // "logit_bias": { "invalid_token": "not_a_number" }, : Partial Validation is already there
...@@ -1638,6 +1640,7 @@ mod tests { ...@@ -1638,6 +1640,7 @@ mod tests {
common: Default::default(), common: Default::default(),
nvext: None, nvext: None,
chat_template_args: None, chat_template_args: None,
unsupported_fields: Default::default(),
}; };
let result = validate_chat_completion_fields_generic(&request); let result = validate_chat_completion_fields_generic(&request);
...@@ -1666,6 +1669,7 @@ mod tests { ...@@ -1666,6 +1669,7 @@ mod tests {
common: Default::default(), common: Default::default(),
nvext: None, nvext: None,
chat_template_args: None, chat_template_args: None,
unsupported_fields: Default::default(),
}; };
let result = validate_chat_completion_fields_generic(&request); let result = validate_chat_completion_fields_generic(&request);
assert!(result.is_err()); assert!(result.is_err());
...@@ -1693,6 +1697,7 @@ mod tests { ...@@ -1693,6 +1697,7 @@ mod tests {
common: Default::default(), common: Default::default(),
nvext: None, nvext: None,
chat_template_args: None, chat_template_args: None,
unsupported_fields: Default::default(),
}; };
let result = validate_chat_completion_fields_generic(&request); let result = validate_chat_completion_fields_generic(&request);
assert!(result.is_err()); assert!(result.is_err());
...@@ -1720,6 +1725,7 @@ mod tests { ...@@ -1720,6 +1725,7 @@ mod tests {
common: Default::default(), common: Default::default(),
nvext: None, nvext: None,
chat_template_args: None, chat_template_args: None,
unsupported_fields: Default::default(),
}; };
let result = validate_chat_completion_fields_generic(&request); let result = validate_chat_completion_fields_generic(&request);
assert!(result.is_err()); assert!(result.is_err());
...@@ -1749,6 +1755,7 @@ mod tests { ...@@ -1749,6 +1755,7 @@ mod tests {
.unwrap(), .unwrap(),
nvext: None, nvext: None,
chat_template_args: None, chat_template_args: None,
unsupported_fields: Default::default(),
}; };
let result = validate_chat_completion_fields_generic(&request); let result = validate_chat_completion_fields_generic(&request);
assert!(result.is_err()); assert!(result.is_err());
...@@ -1776,6 +1783,7 @@ mod tests { ...@@ -1776,6 +1783,7 @@ mod tests {
common: Default::default(), common: Default::default(),
nvext: None, nvext: None,
chat_template_args: None, chat_template_args: None,
unsupported_fields: Default::default(),
}; };
let result = validate_chat_completion_fields_generic(&request); let result = validate_chat_completion_fields_generic(&request);
assert!(result.is_err()); assert!(result.is_err());
...@@ -1787,4 +1795,46 @@ mod tests { ...@@ -1787,4 +1795,46 @@ mod tests {
); );
} }
} }
#[test]
fn test_unknown_fields_rejected() {
// Test that all known unsupported fields are rejected and all shown in error message
let json = r#"{
"messages": [{"role": "user", "content": "Hello"}],
"model": "test-model",
"add_special_tokens": true,
"documents": ["doc1"],
"chat_template": "custom",
"chat_template_kwargs": {"key": "val"}
}"#;
let request: NvCreateChatCompletionRequest = serde_json::from_str(json).unwrap();
// Verify all unsupported fields were captured
assert!(
request
.unsupported_fields
.contains_key("add_special_tokens")
);
assert!(request.unsupported_fields.contains_key("documents"));
assert!(request.unsupported_fields.contains_key("chat_template"));
assert!(
request
.unsupported_fields
.contains_key("chat_template_kwargs")
);
let result = validate_chat_completion_fields_generic(&request);
assert!(result.is_err());
if let Err(error_response) = result {
assert_eq!(error_response.0, StatusCode::BAD_REQUEST);
let msg = &error_response.1.message;
assert!(msg.contains("Unsupported parameter"));
// Verify all fields appear in the error message
assert!(msg.contains("add_special_tokens"));
assert!(msg.contains("documents"));
assert!(msg.contains("chat_template"));
assert!(msg.contains("chat_template_kwargs"));
}
}
} }
...@@ -44,6 +44,10 @@ pub struct NvCreateChatCompletionRequest { ...@@ -44,6 +44,10 @@ pub struct NvCreateChatCompletionRequest {
/// Extra args to pass to the chat template rendering context /// Extra args to pass to the chat template rendering context
#[serde(default, skip_serializing_if = "Option::is_none")] #[serde(default, skip_serializing_if = "Option::is_none")]
pub chat_template_args: Option<std::collections::HashMap<String, serde_json::Value>>, pub chat_template_args: Option<std::collections::HashMap<String, serde_json::Value>>,
/// Catch-all for unsupported fields - checked during validation
#[serde(flatten, default, skip_serializing)]
pub unsupported_fields: std::collections::HashMap<String, serde_json::Value>,
} }
/// A response structure for unary chat completion responses, embedding OpenAI's /// A response structure for unary chat completion responses, embedding OpenAI's
...@@ -271,6 +275,7 @@ impl OpenAIOutputOptionsProvider for NvCreateChatCompletionRequest { ...@@ -271,6 +275,7 @@ impl OpenAIOutputOptionsProvider for NvCreateChatCompletionRequest {
/// allowing us to validate the data. /// allowing us to validate the data.
impl ValidateRequest for NvCreateChatCompletionRequest { impl ValidateRequest for NvCreateChatCompletionRequest {
fn validate(&self) -> Result<(), anyhow::Error> { fn validate(&self) -> Result<(), anyhow::Error> {
validate::validate_no_unsupported_fields(&self.unsupported_fields)?;
validate::validate_messages(&self.inner.messages)?; validate::validate_messages(&self.inner.messages)?;
validate::validate_model(&self.inner.model)?; validate::validate_model(&self.inner.model)?;
// none for store // none for store
......
...@@ -393,6 +393,7 @@ mod tests { ...@@ -393,6 +393,7 @@ mod tests {
common: Default::default(), common: Default::default(),
nvext: None, nvext: None,
chat_template_args: None, chat_template_args: None,
unsupported_fields: Default::default(),
} }
} }
......
...@@ -189,6 +189,7 @@ impl TryFrom<NvCreateResponse> for NvCreateChatCompletionRequest { ...@@ -189,6 +189,7 @@ impl TryFrom<NvCreateResponse> for NvCreateChatCompletionRequest {
common: Default::default(), common: Default::default(),
nvext: resp.nvext, nvext: resp.nvext,
chat_template_args: None, chat_template_args: None,
unsupported_fields: Default::default(),
}) })
} }
} }
......
...@@ -96,6 +96,20 @@ pub const MAX_REPETITION_PENALTY: f32 = 2.0; ...@@ -96,6 +96,20 @@ pub const MAX_REPETITION_PENALTY: f32 = 2.0;
// Shared Fields // Shared Fields
// //
/// Validates that no unsupported fields are present in the request
pub fn validate_no_unsupported_fields(
unsupported_fields: &std::collections::HashMap<String, serde_json::Value>,
) -> Result<(), anyhow::Error> {
if !unsupported_fields.is_empty() {
let fields: Vec<_> = unsupported_fields
.keys()
.map(|s| format!("`{}`", s))
.collect();
anyhow::bail!("Unsupported parameter(s): {}", fields.join(", "));
}
Ok(())
}
/// Validates the temperature parameter /// Validates the temperature parameter
pub fn validate_temperature(temperature: Option<f32>) -> Result<(), anyhow::Error> { pub fn validate_temperature(temperature: Option<f32>) -> Result<(), anyhow::Error> {
if let Some(temp) = temperature if let Some(temp) = temperature
......
...@@ -770,6 +770,7 @@ async fn test_nv_custom_client() { ...@@ -770,6 +770,7 @@ async fn test_nv_custom_client() {
common: Default::default(), common: Default::default(),
nvext: None, nvext: None,
chat_template_args: None, chat_template_args: None,
unsupported_fields: Default::default(),
}; };
let result = nv_custom_client.chat_stream(request).await; let result = nv_custom_client.chat_stream(request).await;
...@@ -810,6 +811,7 @@ async fn test_nv_custom_client() { ...@@ -810,6 +811,7 @@ async fn test_nv_custom_client() {
common: Default::default(), common: Default::default(),
nvext: None, nvext: None,
chat_template_args: None, chat_template_args: None,
unsupported_fields: Default::default(),
}; };
let result = nv_custom_client.chat_stream(request).await; let result = nv_custom_client.chat_stream(request).await;
...@@ -851,6 +853,7 @@ async fn test_nv_custom_client() { ...@@ -851,6 +853,7 @@ async fn test_nv_custom_client() {
common: Default::default(), common: Default::default(),
nvext: None, nvext: None,
chat_template_args: None, chat_template_args: None,
unsupported_fields: Default::default(),
}; };
let result = nv_custom_client let result = nv_custom_client
......
...@@ -91,6 +91,7 @@ fn create_mock_chat_completion_request() -> NvCreateChatCompletionRequest { ...@@ -91,6 +91,7 @@ fn create_mock_chat_completion_request() -> NvCreateChatCompletionRequest {
common: CommonExt::default(), common: CommonExt::default(),
nvext: None, nvext: None,
chat_template_args: None, chat_template_args: None,
unsupported_fields: Default::default(),
} }
} }
......
...@@ -272,6 +272,7 @@ impl Request { ...@@ -272,6 +272,7 @@ impl Request {
common: Default::default(), common: Default::default(),
nvext: None, nvext: None,
chat_template_args: None, chat_template_args: None,
unsupported_fields: Default::default(),
} }
} }
} }
......
...@@ -68,6 +68,7 @@ fn test_sampling_parameters_include_stop_str_in_output_extraction() { ...@@ -68,6 +68,7 @@ fn test_sampling_parameters_include_stop_str_in_output_extraction() {
.unwrap(), .unwrap(),
nvext: None, nvext: None,
chat_template_args: None, chat_template_args: None,
unsupported_fields: Default::default(),
}; };
let sampling = request.extract_sampling_options().unwrap(); let sampling = request.extract_sampling_options().unwrap();
...@@ -296,6 +297,7 @@ fn test_serialization_preserves_structure() { ...@@ -296,6 +297,7 @@ fn test_serialization_preserves_structure() {
..Default::default() ..Default::default()
}), }),
chat_template_args: None, chat_template_args: None,
unsupported_fields: Default::default(),
}; };
let json = serde_json::to_value(&request).unwrap(); let json = serde_json::to_value(&request).unwrap();
...@@ -346,6 +348,7 @@ fn test_sampling_parameters_extraction() { ...@@ -346,6 +348,7 @@ fn test_sampling_parameters_extraction() {
.unwrap(), .unwrap(),
nvext: None, nvext: None,
chat_template_args: None, chat_template_args: None,
unsupported_fields: Default::default(),
}; };
let sampling_options = request.extract_sampling_options().unwrap(); let sampling_options = request.extract_sampling_options().unwrap();
......
...@@ -150,6 +150,7 @@ fn create_chat_request(include_usage: Option<bool>) -> NvCreateChatCompletionReq ...@@ -150,6 +150,7 @@ fn create_chat_request(include_usage: Option<bool>) -> NvCreateChatCompletionReq
common: Default::default(), common: Default::default(),
nvext: None, nvext: None,
chat_template_args: None, chat_template_args: None,
unsupported_fields: Default::default(),
} }
} }
...@@ -329,6 +330,7 @@ fn create_nonstreaming_chat_request() -> NvCreateChatCompletionRequest { ...@@ -329,6 +330,7 @@ fn create_nonstreaming_chat_request() -> NvCreateChatCompletionRequest {
common: Default::default(), common: Default::default(),
nvext: None, nvext: None,
chat_template_args: None, chat_template_args: None,
unsupported_fields: Default::default(),
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment