"docs/api/vscode:/vscode.git/clone" did not exist on "617d55c04e72be51d8b4cc4a90d8b136db27da3b"
Unverified Commit 3d036fc4 authored by KrishnanPrash's avatar KrishnanPrash Committed by GitHub
Browse files

feat: Reject unsupported parameters with 400 Bad Request (`/v1/completions`) (#4140)


Signed-off-by: default avatarKrishnan Prashanth <kprashanth@nvidia.com>
parent 6e2b22ea
......@@ -1481,12 +1481,14 @@ dependencies = [
"galil-seiferas",
"hf-hub",
"humantime",
"image",
"itertools 0.14.0",
"json-five",
"minijinja",
"minijinja-contrib",
"modelexpress-client",
"modelexpress-common",
"ndarray",
"offset-allocator",
"oneshot",
"parking_lot",
......@@ -1506,6 +1508,7 @@ dependencies = [
"tmq",
"tokenizers",
"tokio",
"tokio-rayon",
"tokio-stream",
"tokio-util",
"toktrie",
......@@ -3302,6 +3305,16 @@ version = "0.8.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "47e1ffaa40ddd1f3ed91f717a33c8c0ee23fff369e3aa8772b9605cc1d22f4c3"
[[package]]
name = "matrixmultiply"
version = "0.3.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a06de3016e9fae57a36fd14dba131fccf49f74b40b7fbdb472f96e361ec71a08"
dependencies = [
"autocfg",
"rawpointer",
]
[[package]]
name = "maybe-rayon"
version = "0.1.1"
......@@ -3511,6 +3524,21 @@ version = "0.10.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1d87ecb2933e8aeadb3e3a02b828fed80a7528047e68b4f424523a0981a3a084"
[[package]]
name = "ndarray"
version = "0.16.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "882ed72dce9365842bf196bdeedf5055305f11fc8c03dee7bb0194a6cad34841"
dependencies = [
"matrixmultiply",
"num-complex",
"num-integer",
"num-traits",
"portable-atomic",
"portable-atomic-util",
"rawpointer",
]
[[package]]
name = "neli"
version = "0.6.5"
......@@ -4770,6 +4798,12 @@ dependencies = [
"bitflags 2.9.3",
]
[[package]]
name = "rawpointer"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3"
[[package]]
name = "rayon"
version = "1.11.0"
......
......@@ -315,6 +315,7 @@ impl TryFrom<inference::ModelInferRequest> for NvCreateCompletionRequest {
common: Default::default(),
nvext: None,
metadata: None,
unsupported_fields: Default::default(),
})
}
}
......
......@@ -1469,6 +1469,7 @@ mod tests {
common: Default::default(),
nvext: None,
metadata: None,
unsupported_fields: Default::default(),
};
let result = validate_completion_fields_generic(&request);
......@@ -1492,6 +1493,7 @@ mod tests {
common: Default::default(),
nvext: None,
metadata: None,
unsupported_fields: Default::default(),
};
let result = validate_completion_fields_generic(&request);
assert!(result.is_err());
......@@ -1514,6 +1516,7 @@ mod tests {
common: Default::default(),
nvext: None,
metadata: None,
unsupported_fields: Default::default(),
};
let result = validate_completion_fields_generic(&request);
assert!(result.is_err());
......@@ -1536,6 +1539,7 @@ mod tests {
common: Default::default(),
nvext: None,
metadata: None,
unsupported_fields: Default::default(),
};
let result = validate_completion_fields_generic(&request);
assert!(result.is_err());
......@@ -1560,6 +1564,7 @@ mod tests {
.unwrap(),
nvext: None,
metadata: None,
unsupported_fields: Default::default(),
};
let result = validate_completion_fields_generic(&request);
assert!(result.is_err());
......@@ -1582,6 +1587,7 @@ mod tests {
common: Default::default(),
nvext: None,
metadata: None,
unsupported_fields: Default::default(),
};
let result = validate_completion_fields_generic(&request);
assert!(result.is_err());
......@@ -1612,6 +1618,7 @@ mod tests {
"session": {"id": "session-1", "timestamp": 1640995200}
})
.into(),
unsupported_fields: Default::default(),
};
let result = validate_completion_fields_generic(&request);
......@@ -1797,8 +1804,8 @@ mod tests {
}
#[test]
fn test_unknown_fields_rejected() {
// Test that all known unsupported fields are rejected and all shown in error message
fn test_chat_completions_unknown_fields_rejected() {
// Test that known unsupported fields are rejected and all shown in error message
let json = r#"{
"messages": [{"role": "user", "content": "Hello"}],
"model": "test-model",
......@@ -1837,4 +1844,36 @@ mod tests {
assert!(msg.contains("chat_template_kwargs"));
}
}
#[test]
fn test_completions_unsupported_fields_rejected() {
// Test that known unsupported fields are rejected and all shown in error message
let json = r#"{
"model": "test-model",
"prompt": "Hello",
"add_special_tokens": true,
"response_format": {"type": "json_object"}
}"#;
let request: NvCreateCompletionRequest = serde_json::from_str(json).unwrap();
// Verify both unsupported fields were captured
assert!(
request
.unsupported_fields
.contains_key("add_special_tokens")
);
assert!(request.unsupported_fields.contains_key("response_format"));
let result = validate_completion_fields_generic(&request);
assert!(result.is_err());
if let Err(error_response) = result {
assert_eq!(error_response.0, StatusCode::BAD_REQUEST);
let msg = &error_response.1.message;
assert!(msg.contains("Unsupported parameter"));
// Verify both fields appear in error message
assert!(msg.contains("add_special_tokens"));
assert!(msg.contains("response_format"));
}
}
}
......@@ -37,6 +37,10 @@ pub struct NvCreateCompletionRequest {
// metadata - passthrough parameter without restrictions
#[serde(skip_serializing_if = "Option::is_none")]
pub metadata: Option<serde_json::Value>,
/// Catch-all for unsupported fields - checked during validation
#[serde(flatten, default, skip_serializing)]
pub unsupported_fields: std::collections::HashMap<String, serde_json::Value>,
}
#[derive(Serialize, Deserialize, Validate, Debug, Clone)]
......@@ -372,6 +376,7 @@ impl OpenAIOutputOptionsProvider for NvCreateCompletionRequest {
/// allowing us to validate the data.
impl ValidateRequest for NvCreateCompletionRequest {
fn validate(&self) -> Result<(), anyhow::Error> {
validate::validate_no_unsupported_fields(&self.unsupported_fields)?;
validate::validate_model(&self.inner.model)?;
validate::validate_prompt(&self.inner.prompt)?;
validate::validate_suffix(self.inner.suffix.as_deref())?;
......
......@@ -29,6 +29,7 @@ impl CompletionSample {
common: Default::default(),
nvext: None,
metadata: None,
unsupported_fields: Default::default(),
};
Ok(Self {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment