Unverified Commit 1f6b83be authored by Ayush Agarwal's avatar Ayush Agarwal Committed by GitHub
Browse files

chore: added include_stop_str_in_output (#2782)

parent bb9566b7
......@@ -337,6 +337,9 @@ pub struct SamplingOptions {
/// The seed to use when sampling
pub seed: Option<i64>,
/// Whether to include the stop string in the output.
pub include_stop_str_in_output: Option<bool>,
/// Guided Decoding Options
pub guided_decoding: Option<GuidedDecodingOptions>,
}
......
......@@ -97,6 +97,7 @@ impl<T: OpenAISamplingOptionsProvider + CommonExtProvider> SamplingOptionsProvid
.map_err(|e| anyhow::anyhow!("Error validating presence_penalty: {}", e))?;
let top_k = CommonExtProvider::get_top_k(self);
let repetition_penalty = CommonExtProvider::get_repetition_penalty(self);
let include_stop_str_in_output = CommonExtProvider::get_include_stop_str_in_output(self);
if let Some(nvext) = self.nvext() {
let greedy = nvext.greed_sampling.unwrap_or(false);
......@@ -141,6 +142,7 @@ impl<T: OpenAISamplingOptionsProvider + CommonExtProvider> SamplingOptionsProvid
use_beam_search: None,
length_penalty: None,
guided_decoding,
include_stop_str_in_output,
})
}
}
......
......@@ -216,6 +216,10 @@ impl CommonExtProvider for NvCreateChatCompletionRequest {
.and_then(|nv| nv.repetition_penalty.as_ref()),
)
}
fn get_include_stop_str_in_output(&self) -> Option<bool> {
self.common.include_stop_str_in_output
}
}
/// Implements `OpenAIStopConditionsProvider` for `NvCreateChatCompletionRequest`,
......
......@@ -35,6 +35,11 @@ pub struct CommonExt {
#[validate(range(exclusive_min = 0.0, max = 2.0))]
pub repetition_penalty: Option<f32>,
/// include_stop_str_in_output
#[serde(default, skip_serializing_if = "Option::is_none")]
#[builder(default, setter(strip_option))]
pub include_stop_str_in_output: Option<bool>,
/// Guided Decoding Options
/// If specified, the output will be a JSON object. Can be a string, an object, or null.
#[serde(default, skip_serializing_if = "Option::is_none")]
......@@ -83,6 +88,7 @@ pub trait CommonExtProvider {
/// Other sampling Options
fn get_top_k(&self) -> Option<i32>;
fn get_repetition_penalty(&self) -> Option<f32>;
fn get_include_stop_str_in_output(&self) -> Option<bool>;
}
/// Helper function to emit deprecation warnings for nvext parameters
......@@ -132,6 +138,7 @@ mod tests {
assert_eq!(common_ext.guided_grammar, None);
assert_eq!(common_ext.guided_choice, None);
assert_eq!(common_ext.guided_decoding_backend, None);
assert_eq!(common_ext.include_stop_str_in_output, None);
}
#[test]
......@@ -141,6 +148,7 @@ mod tests {
.min_tokens(10)
.top_k(50)
.repetition_penalty(1.2)
.include_stop_str_in_output(true)
.guided_json(serde_json::json!({"key": "value"}))
.guided_regex("regex".to_string())
.guided_grammar("grammar".to_string())
......@@ -153,6 +161,7 @@ mod tests {
assert_eq!(common_ext.min_tokens, Some(10));
assert_eq!(common_ext.top_k, Some(50));
assert_eq!(common_ext.repetition_penalty, Some(1.2));
assert_eq!(common_ext.include_stop_str_in_output, Some(true));
assert_eq!(
common_ext.guided_json.as_ref(),
Some(&serde_json::json!({"key": "value"}))
......@@ -175,11 +184,13 @@ mod tests {
let common_ext = CommonExt::builder()
.ignore_eos(false)
.min_tokens(5)
.include_stop_str_in_output(true)
.build()
.unwrap();
assert_eq!(common_ext.ignore_eos, Some(false));
assert_eq!(common_ext.min_tokens, Some(5));
assert_eq!(common_ext.include_stop_str_in_output, Some(true));
}
#[test]
......@@ -190,6 +201,7 @@ mod tests {
min_tokens: Some(0), // Should be valid (min = 0)
top_k: None,
repetition_penalty: None,
include_stop_str_in_output: None,
guided_json: None,
guided_regex: None,
guided_grammar: None,
......@@ -208,6 +220,7 @@ mod tests {
assert_eq!(common_ext.min_tokens, None);
assert_eq!(common_ext.top_k, None);
assert_eq!(common_ext.repetition_penalty, None);
assert_eq!(common_ext.include_stop_str_in_output, None);
assert!(common_ext.validate().is_ok());
}
......@@ -220,6 +233,7 @@ mod tests {
assert_eq!(common_ext.min_tokens, None);
assert_eq!(common_ext.top_k, None);
assert_eq!(common_ext.repetition_penalty, None);
assert_eq!(common_ext.include_stop_str_in_output, None);
assert!(common_ext.validate().is_ok());
}
......
......@@ -210,6 +210,10 @@ impl CommonExtProvider for NvCreateCompletionRequest {
.and_then(|nv| nv.repetition_penalty.as_ref()),
)
}
fn get_include_stop_str_in_output(&self) -> Option<bool> {
self.common.include_stop_str_in_output
}
}
impl OpenAIStopConditionsProvider for NvCreateCompletionRequest {
......
......@@ -25,6 +25,52 @@ fn test_chat_completions_ignore_eos_from_common() {
assert_eq!(request.common.ignore_eos, Some(true));
assert_eq!(request.common.min_tokens, Some(100));
assert_eq!(request.common.include_stop_str_in_output, None);
}
#[test]
fn test_chat_completions_include_stop_str_in_output_from_common() {
let json_str = r#"{
"model": "test-model",
"messages": [{"role": "user", "content": "Hello"}],
"include_stop_str_in_output": true
}"#;
let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();
assert_eq!(request.common.include_stop_str_in_output, Some(true));
assert_eq!(request.get_include_stop_str_in_output(), Some(true));
}
#[test]
fn test_completions_include_stop_str_in_output_from_common() {
let json_str = r#"{
"model": "test-model",
"prompt": "Hello world",
"include_stop_str_in_output": true
}"#;
let request: NvCreateCompletionRequest = serde_json::from_str(json_str).unwrap();
assert_eq!(request.common.include_stop_str_in_output, Some(true));
// When exposed on completions, this should also be available via the provider
assert_eq!(request.get_include_stop_str_in_output(), Some(true));
}
#[test]
fn test_sampling_parameters_include_stop_str_in_output_extraction() {
use dynamo_llm::protocols::common::SamplingOptionsProvider;
let request = NvCreateChatCompletionRequest {
inner: Default::default(),
common: CommonExt::builder()
.include_stop_str_in_output(true)
.build()
.unwrap(),
nvext: None,
};
let sampling = request.extract_sampling_options().unwrap();
assert_eq!(sampling.include_stop_str_in_output, Some(true));
}
#[test]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment