Unverified Commit c837b5ba authored by KrishnanPrash's avatar KrishnanPrash Committed by GitHub
Browse files

feat: Reject unsupported parameters with 400 Bad Request (#4021)


Signed-off-by: default avatarKrishnan Prashanth <kprashanth@nvidia.com>
parent f31a1dad
......@@ -229,6 +229,7 @@ async fn evaluate(
common: Default::default(),
nvext: None,
chat_template_args: None,
unsupported_fields: Default::default(),
};
let mut stream = engine.generate(Context::new(req)).await?;
let mut output = String::new();
......
......@@ -112,6 +112,7 @@ async fn main_loop(
common: Default::default(),
nvext: None,
chat_template_args: None,
unsupported_fields: Default::default(),
};
// Call the model
......
......@@ -1403,6 +1403,7 @@ mod tests {
common: Default::default(),
nvext: None,
chat_template_args: None,
unsupported_fields: Default::default(),
};
let result = validate_chat_completion_required_fields(&request);
assert!(result.is_err());
......@@ -1431,6 +1432,7 @@ mod tests {
common: Default::default(),
nvext: None,
chat_template_args: None,
unsupported_fields: Default::default(),
};
let result = validate_chat_completion_required_fields(&request);
assert!(result.is_ok());
......@@ -1451,7 +1453,7 @@ mod tests {
// 11. Repetition Penalty: Should be a float between 0.0 and 2.0 : Done
// 12. Logprobs: Should be a positive integer between 0 and 5 : Done
// invalid or non existing user : Only empty string is not allowed validation is there. How can we check non-extisting user ?
// add_special_tokens null or invalid : Not Done
// Unknown fields : Done (rejected via extra_fields catch-all)
// guided_whitespace_pattern null or invalid : Not Done
// "response_format": { "type": "invalid_format" } : Not Done
// "logit_bias": { "invalid_token": "not_a_number" }, : Partial Validation is already there
......@@ -1638,6 +1640,7 @@ mod tests {
common: Default::default(),
nvext: None,
chat_template_args: None,
unsupported_fields: Default::default(),
};
let result = validate_chat_completion_fields_generic(&request);
......@@ -1666,6 +1669,7 @@ mod tests {
common: Default::default(),
nvext: None,
chat_template_args: None,
unsupported_fields: Default::default(),
};
let result = validate_chat_completion_fields_generic(&request);
assert!(result.is_err());
......@@ -1693,6 +1697,7 @@ mod tests {
common: Default::default(),
nvext: None,
chat_template_args: None,
unsupported_fields: Default::default(),
};
let result = validate_chat_completion_fields_generic(&request);
assert!(result.is_err());
......@@ -1720,6 +1725,7 @@ mod tests {
common: Default::default(),
nvext: None,
chat_template_args: None,
unsupported_fields: Default::default(),
};
let result = validate_chat_completion_fields_generic(&request);
assert!(result.is_err());
......@@ -1749,6 +1755,7 @@ mod tests {
.unwrap(),
nvext: None,
chat_template_args: None,
unsupported_fields: Default::default(),
};
let result = validate_chat_completion_fields_generic(&request);
assert!(result.is_err());
......@@ -1776,6 +1783,7 @@ mod tests {
common: Default::default(),
nvext: None,
chat_template_args: None,
unsupported_fields: Default::default(),
};
let result = validate_chat_completion_fields_generic(&request);
assert!(result.is_err());
......@@ -1787,4 +1795,46 @@ mod tests {
);
}
}
#[test]
fn test_unknown_fields_rejected() {
// Test that all known unsupported fields are rejected and all shown in error message
let json = r#"{
"messages": [{"role": "user", "content": "Hello"}],
"model": "test-model",
"add_special_tokens": true,
"documents": ["doc1"],
"chat_template": "custom",
"chat_template_kwargs": {"key": "val"}
}"#;
let request: NvCreateChatCompletionRequest = serde_json::from_str(json).unwrap();
// Verify all unsupported fields were captured
assert!(
request
.unsupported_fields
.contains_key("add_special_tokens")
);
assert!(request.unsupported_fields.contains_key("documents"));
assert!(request.unsupported_fields.contains_key("chat_template"));
assert!(
request
.unsupported_fields
.contains_key("chat_template_kwargs")
);
let result = validate_chat_completion_fields_generic(&request);
assert!(result.is_err());
if let Err(error_response) = result {
assert_eq!(error_response.0, StatusCode::BAD_REQUEST);
let msg = &error_response.1.message;
assert!(msg.contains("Unsupported parameter"));
// Verify all fields appear in the error message
assert!(msg.contains("add_special_tokens"));
assert!(msg.contains("documents"));
assert!(msg.contains("chat_template"));
assert!(msg.contains("chat_template_kwargs"));
}
}
}
......@@ -44,6 +44,10 @@ pub struct NvCreateChatCompletionRequest {
/// Extra args to pass to the chat template rendering context
#[serde(default, skip_serializing_if = "Option::is_none")]
pub chat_template_args: Option<std::collections::HashMap<String, serde_json::Value>>,
/// Catch-all for unsupported fields - checked during validation
#[serde(flatten, default, skip_serializing)]
pub unsupported_fields: std::collections::HashMap<String, serde_json::Value>,
}
/// A response structure for unary chat completion responses, embedding OpenAI's
......@@ -271,6 +275,7 @@ impl OpenAIOutputOptionsProvider for NvCreateChatCompletionRequest {
/// allowing us to validate the data.
impl ValidateRequest for NvCreateChatCompletionRequest {
fn validate(&self) -> Result<(), anyhow::Error> {
validate::validate_no_unsupported_fields(&self.unsupported_fields)?;
validate::validate_messages(&self.inner.messages)?;
validate::validate_model(&self.inner.model)?;
// none for store
......
......@@ -393,6 +393,7 @@ mod tests {
common: Default::default(),
nvext: None,
chat_template_args: None,
unsupported_fields: Default::default(),
}
}
......
......@@ -189,6 +189,7 @@ impl TryFrom<NvCreateResponse> for NvCreateChatCompletionRequest {
common: Default::default(),
nvext: resp.nvext,
chat_template_args: None,
unsupported_fields: Default::default(),
})
}
}
......
......@@ -96,6 +96,20 @@ pub const MAX_REPETITION_PENALTY: f32 = 2.0;
// Shared Fields
//
/// Validates that no unsupported fields are present in the request
pub fn validate_no_unsupported_fields(
unsupported_fields: &std::collections::HashMap<String, serde_json::Value>,
) -> Result<(), anyhow::Error> {
if !unsupported_fields.is_empty() {
let fields: Vec<_> = unsupported_fields
.keys()
.map(|s| format!("`{}`", s))
.collect();
anyhow::bail!("Unsupported parameter(s): {}", fields.join(", "));
}
Ok(())
}
/// Validates the temperature parameter
pub fn validate_temperature(temperature: Option<f32>) -> Result<(), anyhow::Error> {
if let Some(temp) = temperature
......
......@@ -770,6 +770,7 @@ async fn test_nv_custom_client() {
common: Default::default(),
nvext: None,
chat_template_args: None,
unsupported_fields: Default::default(),
};
let result = nv_custom_client.chat_stream(request).await;
......@@ -810,6 +811,7 @@ async fn test_nv_custom_client() {
common: Default::default(),
nvext: None,
chat_template_args: None,
unsupported_fields: Default::default(),
};
let result = nv_custom_client.chat_stream(request).await;
......@@ -851,6 +853,7 @@ async fn test_nv_custom_client() {
common: Default::default(),
nvext: None,
chat_template_args: None,
unsupported_fields: Default::default(),
};
let result = nv_custom_client
......
......@@ -91,6 +91,7 @@ fn create_mock_chat_completion_request() -> NvCreateChatCompletionRequest {
common: CommonExt::default(),
nvext: None,
chat_template_args: None,
unsupported_fields: Default::default(),
}
}
......
......@@ -272,6 +272,7 @@ impl Request {
common: Default::default(),
nvext: None,
chat_template_args: None,
unsupported_fields: Default::default(),
}
}
}
......
......@@ -68,6 +68,7 @@ fn test_sampling_parameters_include_stop_str_in_output_extraction() {
.unwrap(),
nvext: None,
chat_template_args: None,
unsupported_fields: Default::default(),
};
let sampling = request.extract_sampling_options().unwrap();
......@@ -296,6 +297,7 @@ fn test_serialization_preserves_structure() {
..Default::default()
}),
chat_template_args: None,
unsupported_fields: Default::default(),
};
let json = serde_json::to_value(&request).unwrap();
......@@ -346,6 +348,7 @@ fn test_sampling_parameters_extraction() {
.unwrap(),
nvext: None,
chat_template_args: None,
unsupported_fields: Default::default(),
};
let sampling_options = request.extract_sampling_options().unwrap();
......
......@@ -150,6 +150,7 @@ fn create_chat_request(include_usage: Option<bool>) -> NvCreateChatCompletionReq
common: Default::default(),
nvext: None,
chat_template_args: None,
unsupported_fields: Default::default(),
}
}
......@@ -329,6 +330,7 @@ fn create_nonstreaming_chat_request() -> NvCreateChatCompletionRequest {
common: Default::default(),
nvext: None,
chat_template_args: None,
unsupported_fields: Default::default(),
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment