Unverified commit 3d036fc4, authored by KrishnanPrash and committed by GitHub
Browse files

feat: Reject unsupported parameters with 400 Bad Request (`/v1/completions`) (#4140)


Signed-off-by: Krishnan Prashanth <kprashanth@nvidia.com>
parent 6e2b22ea
...@@ -1481,12 +1481,14 @@ dependencies = [ ...@@ -1481,12 +1481,14 @@ dependencies = [
"galil-seiferas", "galil-seiferas",
"hf-hub", "hf-hub",
"humantime", "humantime",
"image",
"itertools 0.14.0", "itertools 0.14.0",
"json-five", "json-five",
"minijinja", "minijinja",
"minijinja-contrib", "minijinja-contrib",
"modelexpress-client", "modelexpress-client",
"modelexpress-common", "modelexpress-common",
"ndarray",
"offset-allocator", "offset-allocator",
"oneshot", "oneshot",
"parking_lot", "parking_lot",
...@@ -1506,6 +1508,7 @@ dependencies = [ ...@@ -1506,6 +1508,7 @@ dependencies = [
"tmq", "tmq",
"tokenizers", "tokenizers",
"tokio", "tokio",
"tokio-rayon",
"tokio-stream", "tokio-stream",
"tokio-util", "tokio-util",
"toktrie", "toktrie",
...@@ -3302,6 +3305,16 @@ version = "0.8.4" ...@@ -3302,6 +3305,16 @@ version = "0.8.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "47e1ffaa40ddd1f3ed91f717a33c8c0ee23fff369e3aa8772b9605cc1d22f4c3" checksum = "47e1ffaa40ddd1f3ed91f717a33c8c0ee23fff369e3aa8772b9605cc1d22f4c3"
[[package]]
name = "matrixmultiply"
version = "0.3.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a06de3016e9fae57a36fd14dba131fccf49f74b40b7fbdb472f96e361ec71a08"
dependencies = [
"autocfg",
"rawpointer",
]
[[package]] [[package]]
name = "maybe-rayon" name = "maybe-rayon"
version = "0.1.1" version = "0.1.1"
...@@ -3511,6 +3524,21 @@ version = "0.10.1" ...@@ -3511,6 +3524,21 @@ version = "0.10.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1d87ecb2933e8aeadb3e3a02b828fed80a7528047e68b4f424523a0981a3a084" checksum = "1d87ecb2933e8aeadb3e3a02b828fed80a7528047e68b4f424523a0981a3a084"
[[package]]
name = "ndarray"
version = "0.16.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "882ed72dce9365842bf196bdeedf5055305f11fc8c03dee7bb0194a6cad34841"
dependencies = [
"matrixmultiply",
"num-complex",
"num-integer",
"num-traits",
"portable-atomic",
"portable-atomic-util",
"rawpointer",
]
[[package]] [[package]]
name = "neli" name = "neli"
version = "0.6.5" version = "0.6.5"
...@@ -4770,6 +4798,12 @@ dependencies = [ ...@@ -4770,6 +4798,12 @@ dependencies = [
"bitflags 2.9.3", "bitflags 2.9.3",
] ]
[[package]]
name = "rawpointer"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3"
[[package]] [[package]]
name = "rayon" name = "rayon"
version = "1.11.0" version = "1.11.0"
......
...@@ -315,6 +315,7 @@ impl TryFrom<inference::ModelInferRequest> for NvCreateCompletionRequest { ...@@ -315,6 +315,7 @@ impl TryFrom<inference::ModelInferRequest> for NvCreateCompletionRequest {
common: Default::default(), common: Default::default(),
nvext: None, nvext: None,
metadata: None, metadata: None,
unsupported_fields: Default::default(),
}) })
} }
} }
......
...@@ -1469,6 +1469,7 @@ mod tests { ...@@ -1469,6 +1469,7 @@ mod tests {
common: Default::default(), common: Default::default(),
nvext: None, nvext: None,
metadata: None, metadata: None,
unsupported_fields: Default::default(),
}; };
let result = validate_completion_fields_generic(&request); let result = validate_completion_fields_generic(&request);
...@@ -1492,6 +1493,7 @@ mod tests { ...@@ -1492,6 +1493,7 @@ mod tests {
common: Default::default(), common: Default::default(),
nvext: None, nvext: None,
metadata: None, metadata: None,
unsupported_fields: Default::default(),
}; };
let result = validate_completion_fields_generic(&request); let result = validate_completion_fields_generic(&request);
assert!(result.is_err()); assert!(result.is_err());
...@@ -1514,6 +1516,7 @@ mod tests { ...@@ -1514,6 +1516,7 @@ mod tests {
common: Default::default(), common: Default::default(),
nvext: None, nvext: None,
metadata: None, metadata: None,
unsupported_fields: Default::default(),
}; };
let result = validate_completion_fields_generic(&request); let result = validate_completion_fields_generic(&request);
assert!(result.is_err()); assert!(result.is_err());
...@@ -1536,6 +1539,7 @@ mod tests { ...@@ -1536,6 +1539,7 @@ mod tests {
common: Default::default(), common: Default::default(),
nvext: None, nvext: None,
metadata: None, metadata: None,
unsupported_fields: Default::default(),
}; };
let result = validate_completion_fields_generic(&request); let result = validate_completion_fields_generic(&request);
assert!(result.is_err()); assert!(result.is_err());
...@@ -1560,6 +1564,7 @@ mod tests { ...@@ -1560,6 +1564,7 @@ mod tests {
.unwrap(), .unwrap(),
nvext: None, nvext: None,
metadata: None, metadata: None,
unsupported_fields: Default::default(),
}; };
let result = validate_completion_fields_generic(&request); let result = validate_completion_fields_generic(&request);
assert!(result.is_err()); assert!(result.is_err());
...@@ -1582,6 +1587,7 @@ mod tests { ...@@ -1582,6 +1587,7 @@ mod tests {
common: Default::default(), common: Default::default(),
nvext: None, nvext: None,
metadata: None, metadata: None,
unsupported_fields: Default::default(),
}; };
let result = validate_completion_fields_generic(&request); let result = validate_completion_fields_generic(&request);
assert!(result.is_err()); assert!(result.is_err());
...@@ -1612,6 +1618,7 @@ mod tests { ...@@ -1612,6 +1618,7 @@ mod tests {
"session": {"id": "session-1", "timestamp": 1640995200} "session": {"id": "session-1", "timestamp": 1640995200}
}) })
.into(), .into(),
unsupported_fields: Default::default(),
}; };
let result = validate_completion_fields_generic(&request); let result = validate_completion_fields_generic(&request);
...@@ -1797,8 +1804,8 @@ mod tests { ...@@ -1797,8 +1804,8 @@ mod tests {
} }
#[test] #[test]
fn test_unknown_fields_rejected() { fn test_chat_completions_unknown_fields_rejected() {
// Test that all known unsupported fields are rejected and all shown in error message // Test that known unsupported fields are rejected and all shown in error message
let json = r#"{ let json = r#"{
"messages": [{"role": "user", "content": "Hello"}], "messages": [{"role": "user", "content": "Hello"}],
"model": "test-model", "model": "test-model",
...@@ -1837,4 +1844,36 @@ mod tests { ...@@ -1837,4 +1844,36 @@ mod tests {
assert!(msg.contains("chat_template_kwargs")); assert!(msg.contains("chat_template_kwargs"));
} }
} }
/// A `/v1/completions` payload carrying parameters outside the supported set
/// must deserialize successfully, record those parameters in
/// `unsupported_fields`, and then be rejected by validation with a
/// 400 Bad Request whose message names every offending parameter.
#[test]
fn test_completions_unsupported_fields_rejected() {
    // Two valid fields (`model`, `prompt`) plus two that should be rejected.
    let payload = r#"{
"model": "test-model",
"prompt": "Hello",
"add_special_tokens": true,
"response_format": {"type": "json_object"}
}"#;
    let parsed: NvCreateCompletionRequest = serde_json::from_str(payload).unwrap();

    // Deserialization itself succeeds; the extra keys are captured by the catch-all map.
    for field in ["add_special_tokens", "response_format"] {
        assert!(
            parsed.unsupported_fields.contains_key(field),
            "expected `{field}` to be captured as an unsupported field"
        );
    }

    // Validation is where the request actually gets refused.
    let outcome = validate_completion_fields_generic(&parsed);
    assert!(outcome.is_err());
    let Err(rejection) = outcome else {
        return;
    };

    assert_eq!(rejection.0, StatusCode::BAD_REQUEST);

    // The message must carry the generic prefix and list both offending parameters.
    let message = &rejection.1.message;
    for fragment in [
        "Unsupported parameter",
        "add_special_tokens",
        "response_format",
    ] {
        assert!(
            message.contains(fragment),
            "error message should mention `{fragment}`"
        );
    }
}
} }
...@@ -37,6 +37,10 @@ pub struct NvCreateCompletionRequest { ...@@ -37,6 +37,10 @@ pub struct NvCreateCompletionRequest {
// metadata - passthrough parameter without restrictions // metadata - passthrough parameter without restrictions
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub metadata: Option<serde_json::Value>, pub metadata: Option<serde_json::Value>,
/// Catch-all for unsupported fields - checked during validation
#[serde(flatten, default, skip_serializing)]
pub unsupported_fields: std::collections::HashMap<String, serde_json::Value>,
} }
#[derive(Serialize, Deserialize, Validate, Debug, Clone)] #[derive(Serialize, Deserialize, Validate, Debug, Clone)]
...@@ -372,6 +376,7 @@ impl OpenAIOutputOptionsProvider for NvCreateCompletionRequest { ...@@ -372,6 +376,7 @@ impl OpenAIOutputOptionsProvider for NvCreateCompletionRequest {
/// allowing us to validate the data. /// allowing us to validate the data.
impl ValidateRequest for NvCreateCompletionRequest { impl ValidateRequest for NvCreateCompletionRequest {
fn validate(&self) -> Result<(), anyhow::Error> { fn validate(&self) -> Result<(), anyhow::Error> {
validate::validate_no_unsupported_fields(&self.unsupported_fields)?;
validate::validate_model(&self.inner.model)?; validate::validate_model(&self.inner.model)?;
validate::validate_prompt(&self.inner.prompt)?; validate::validate_prompt(&self.inner.prompt)?;
validate::validate_suffix(self.inner.suffix.as_deref())?; validate::validate_suffix(self.inner.suffix.as_deref())?;
......
...@@ -29,6 +29,7 @@ impl CompletionSample { ...@@ -29,6 +29,7 @@ impl CompletionSample {
common: Default::default(), common: Default::default(),
nvext: None, nvext: None,
metadata: None, metadata: None,
unsupported_fields: Default::default(),
}; };
Ok(Self { Ok(Self {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment