Unverified Commit b4c8d948 authored by KrishnanPrash's avatar KrishnanPrash Committed by GitHub
Browse files

feat: echo parameter validation for `/v1/completions` (#3813)


Signed-off-by: default avatarKrishnan Prashanth <kprashanth@nvidia.com>
parent 6739154d
...@@ -18,6 +18,77 @@ use crate::error::OpenAIError; ...@@ -18,6 +18,77 @@ use crate::error::OpenAIError;
use super::{ChatCompletionStreamOptions, Choice, CompletionUsage, Prompt, Stop}; use super::{ChatCompletionStreamOptions, Choice, CompletionUsage, Prompt, Stop};
/// Custom deserializer for the echo parameter that only accepts booleans.
/// Rejects integers and strings with clear error messages.
fn deserialize_echo_bool<'de, D>(deserializer: D) -> Result<Option<bool>, D::Error>
where
D: serde::Deserializer<'de>,
{
// Outer visitor: handles Option semantics (Some/None/null)
struct StrictBoolVisitor;
impl<'de> serde::de::Visitor<'de> for StrictBoolVisitor {
type Value = Option<bool>;
// Required by Visitor trait
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str("echo parameter to be a boolean (true or false) or null")
}
fn visit_some<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
where
D: serde::Deserializer<'de>,
{
deserializer.deserialize_any(BoolOnlyVisitor)
}
fn visit_none<E>(self) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
Ok(None)
}
fn visit_unit<E>(self) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
Ok(None)
}
}
// Inner visitor: validates type is boolean, rejects integers and strings
struct BoolOnlyVisitor;
impl<'de> serde::de::Visitor<'de> for BoolOnlyVisitor {
type Value = Option<bool>;
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str("echo parameter to be a boolean (true or false) or null")
}
fn visit_bool<E>(self, value: bool) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
Ok(Some(value))
}
// Explicitly reject strings (including "null", "true", "false")
fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
Err(E::invalid_type(
serde::de::Unexpected::Str(value),
&"echo parameter to be a boolean (true or false) or null",
))
}
}
deserializer.deserialize_option(StrictBoolVisitor)
}
#[derive(Clone, Serialize, Deserialize, Default, Debug, Builder, PartialEq)] #[derive(Clone, Serialize, Deserialize, Default, Debug, Builder, PartialEq)]
#[builder(name = "CreateCompletionRequestArgs")] #[builder(name = "CreateCompletionRequestArgs")]
#[builder(pattern = "mutable")] #[builder(pattern = "mutable")]
...@@ -80,6 +151,7 @@ pub struct CreateCompletionRequest { ...@@ -80,6 +151,7 @@ pub struct CreateCompletionRequest {
/// Echo back the prompt in addition to the completion /// Echo back the prompt in addition to the completion
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
#[serde(default, deserialize_with = "deserialize_echo_bool")]
pub echo: Option<bool>, pub echo: Option<bool>,
/// Up to 4 sequences where the API will stop generating further tokens. The returned text will not contain the stop sequence. /// Up to 4 sequences where the API will stop generating further tokens. The returned text will not contain the stop sequence.
...@@ -149,3 +221,30 @@ pub struct CreateCompletionResponse { ...@@ -149,3 +221,30 @@ pub struct CreateCompletionResponse {
/// Parsed server side events stream until an \[DONE\] is received from server. /// Parsed server side events stream until an \[DONE\] is received from server.
pub type CompletionResponseStream = pub type CompletionResponseStream =
Pin<Box<dyn Stream<Item = Result<CreateCompletionResponse, OpenAIError>> + Send>>; Pin<Box<dyn Stream<Item = Result<CreateCompletionResponse, OpenAIError>> + Send>>;
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn echo_rejects_integer() {
let json = r#"{"model": "test_model", "prompt": "test", "echo": 1}"#;
let result: Result<CreateCompletionRequest, _> = serde_json::from_str(json);
assert!(result.is_err());
let err_msg = result.unwrap_err().to_string();
assert!(err_msg.contains("invalid type"));
assert!(err_msg.contains("integer"));
assert!(err_msg.contains("echo parameter"));
}
#[test]
fn echo_rejects_string() {
let json = r#"{"model": "test_model", "prompt": "test", "echo": "null"}"#;
let result: Result<CreateCompletionRequest, _> = serde_json::from_str(json);
assert!(result.is_err());
let err_msg = result.unwrap_err().to_string();
assert!(err_msg.contains("invalid type"));
assert!(err_msg.contains("string"));
assert!(err_msg.contains("echo parameter"));
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment