Unverified Commit 1f6b83be authored by Ayush Agarwal's avatar Ayush Agarwal Committed by GitHub
Browse files

chore: added include_stop_str_in_output (#2782)

parent bb9566b7
...@@ -337,6 +337,9 @@ pub struct SamplingOptions { ...@@ -337,6 +337,9 @@ pub struct SamplingOptions {
/// The seed to use when sampling /// The seed to use when sampling
pub seed: Option<i64>, pub seed: Option<i64>,
/// Whether to include the stop string in the output.
pub include_stop_str_in_output: Option<bool>,
/// Guided Decoding Options /// Guided Decoding Options
pub guided_decoding: Option<GuidedDecodingOptions>, pub guided_decoding: Option<GuidedDecodingOptions>,
} }
......
...@@ -97,6 +97,7 @@ impl<T: OpenAISamplingOptionsProvider + CommonExtProvider> SamplingOptionsProvid ...@@ -97,6 +97,7 @@ impl<T: OpenAISamplingOptionsProvider + CommonExtProvider> SamplingOptionsProvid
.map_err(|e| anyhow::anyhow!("Error validating presence_penalty: {}", e))?; .map_err(|e| anyhow::anyhow!("Error validating presence_penalty: {}", e))?;
let top_k = CommonExtProvider::get_top_k(self); let top_k = CommonExtProvider::get_top_k(self);
let repetition_penalty = CommonExtProvider::get_repetition_penalty(self); let repetition_penalty = CommonExtProvider::get_repetition_penalty(self);
let include_stop_str_in_output = CommonExtProvider::get_include_stop_str_in_output(self);
if let Some(nvext) = self.nvext() { if let Some(nvext) = self.nvext() {
let greedy = nvext.greed_sampling.unwrap_or(false); let greedy = nvext.greed_sampling.unwrap_or(false);
...@@ -141,6 +142,7 @@ impl<T: OpenAISamplingOptionsProvider + CommonExtProvider> SamplingOptionsProvid ...@@ -141,6 +142,7 @@ impl<T: OpenAISamplingOptionsProvider + CommonExtProvider> SamplingOptionsProvid
use_beam_search: None, use_beam_search: None,
length_penalty: None, length_penalty: None,
guided_decoding, guided_decoding,
include_stop_str_in_output,
}) })
} }
} }
......
...@@ -216,6 +216,10 @@ impl CommonExtProvider for NvCreateChatCompletionRequest { ...@@ -216,6 +216,10 @@ impl CommonExtProvider for NvCreateChatCompletionRequest {
.and_then(|nv| nv.repetition_penalty.as_ref()), .and_then(|nv| nv.repetition_penalty.as_ref()),
) )
} }
fn get_include_stop_str_in_output(&self) -> Option<bool> {
self.common.include_stop_str_in_output
}
} }
/// Implements `OpenAIStopConditionsProvider` for `NvCreateChatCompletionRequest`, /// Implements `OpenAIStopConditionsProvider` for `NvCreateChatCompletionRequest`,
......
...@@ -35,6 +35,11 @@ pub struct CommonExt { ...@@ -35,6 +35,11 @@ pub struct CommonExt {
#[validate(range(exclusive_min = 0.0, max = 2.0))] #[validate(range(exclusive_min = 0.0, max = 2.0))]
pub repetition_penalty: Option<f32>, pub repetition_penalty: Option<f32>,
/// include_stop_str_in_output
#[serde(default, skip_serializing_if = "Option::is_none")]
#[builder(default, setter(strip_option))]
pub include_stop_str_in_output: Option<bool>,
/// Guided Decoding Options /// Guided Decoding Options
/// If specified, the output will be a JSON object. Can be a string, an object, or null. /// If specified, the output will be a JSON object. Can be a string, an object, or null.
#[serde(default, skip_serializing_if = "Option::is_none")] #[serde(default, skip_serializing_if = "Option::is_none")]
...@@ -83,6 +88,7 @@ pub trait CommonExtProvider { ...@@ -83,6 +88,7 @@ pub trait CommonExtProvider {
/// Other sampling Options /// Other sampling Options
fn get_top_k(&self) -> Option<i32>; fn get_top_k(&self) -> Option<i32>;
fn get_repetition_penalty(&self) -> Option<f32>; fn get_repetition_penalty(&self) -> Option<f32>;
fn get_include_stop_str_in_output(&self) -> Option<bool>;
} }
/// Helper function to emit deprecation warnings for nvext parameters /// Helper function to emit deprecation warnings for nvext parameters
...@@ -132,6 +138,7 @@ mod tests { ...@@ -132,6 +138,7 @@ mod tests {
assert_eq!(common_ext.guided_grammar, None); assert_eq!(common_ext.guided_grammar, None);
assert_eq!(common_ext.guided_choice, None); assert_eq!(common_ext.guided_choice, None);
assert_eq!(common_ext.guided_decoding_backend, None); assert_eq!(common_ext.guided_decoding_backend, None);
assert_eq!(common_ext.include_stop_str_in_output, None);
} }
#[test] #[test]
...@@ -141,6 +148,7 @@ mod tests { ...@@ -141,6 +148,7 @@ mod tests {
.min_tokens(10) .min_tokens(10)
.top_k(50) .top_k(50)
.repetition_penalty(1.2) .repetition_penalty(1.2)
.include_stop_str_in_output(true)
.guided_json(serde_json::json!({"key": "value"})) .guided_json(serde_json::json!({"key": "value"}))
.guided_regex("regex".to_string()) .guided_regex("regex".to_string())
.guided_grammar("grammar".to_string()) .guided_grammar("grammar".to_string())
...@@ -153,6 +161,7 @@ mod tests { ...@@ -153,6 +161,7 @@ mod tests {
assert_eq!(common_ext.min_tokens, Some(10)); assert_eq!(common_ext.min_tokens, Some(10));
assert_eq!(common_ext.top_k, Some(50)); assert_eq!(common_ext.top_k, Some(50));
assert_eq!(common_ext.repetition_penalty, Some(1.2)); assert_eq!(common_ext.repetition_penalty, Some(1.2));
assert_eq!(common_ext.include_stop_str_in_output, Some(true));
assert_eq!( assert_eq!(
common_ext.guided_json.as_ref(), common_ext.guided_json.as_ref(),
Some(&serde_json::json!({"key": "value"})) Some(&serde_json::json!({"key": "value"}))
...@@ -175,11 +184,13 @@ mod tests { ...@@ -175,11 +184,13 @@ mod tests {
let common_ext = CommonExt::builder() let common_ext = CommonExt::builder()
.ignore_eos(false) .ignore_eos(false)
.min_tokens(5) .min_tokens(5)
.include_stop_str_in_output(true)
.build() .build()
.unwrap(); .unwrap();
assert_eq!(common_ext.ignore_eos, Some(false)); assert_eq!(common_ext.ignore_eos, Some(false));
assert_eq!(common_ext.min_tokens, Some(5)); assert_eq!(common_ext.min_tokens, Some(5));
assert_eq!(common_ext.include_stop_str_in_output, Some(true));
} }
#[test] #[test]
...@@ -190,6 +201,7 @@ mod tests { ...@@ -190,6 +201,7 @@ mod tests {
min_tokens: Some(0), // Should be valid (min = 0) min_tokens: Some(0), // Should be valid (min = 0)
top_k: None, top_k: None,
repetition_penalty: None, repetition_penalty: None,
include_stop_str_in_output: None,
guided_json: None, guided_json: None,
guided_regex: None, guided_regex: None,
guided_grammar: None, guided_grammar: None,
...@@ -208,6 +220,7 @@ mod tests { ...@@ -208,6 +220,7 @@ mod tests {
assert_eq!(common_ext.min_tokens, None); assert_eq!(common_ext.min_tokens, None);
assert_eq!(common_ext.top_k, None); assert_eq!(common_ext.top_k, None);
assert_eq!(common_ext.repetition_penalty, None); assert_eq!(common_ext.repetition_penalty, None);
assert_eq!(common_ext.include_stop_str_in_output, None);
assert!(common_ext.validate().is_ok()); assert!(common_ext.validate().is_ok());
} }
...@@ -220,6 +233,7 @@ mod tests { ...@@ -220,6 +233,7 @@ mod tests {
assert_eq!(common_ext.min_tokens, None); assert_eq!(common_ext.min_tokens, None);
assert_eq!(common_ext.top_k, None); assert_eq!(common_ext.top_k, None);
assert_eq!(common_ext.repetition_penalty, None); assert_eq!(common_ext.repetition_penalty, None);
assert_eq!(common_ext.include_stop_str_in_output, None);
assert!(common_ext.validate().is_ok()); assert!(common_ext.validate().is_ok());
} }
......
...@@ -210,6 +210,10 @@ impl CommonExtProvider for NvCreateCompletionRequest { ...@@ -210,6 +210,10 @@ impl CommonExtProvider for NvCreateCompletionRequest {
.and_then(|nv| nv.repetition_penalty.as_ref()), .and_then(|nv| nv.repetition_penalty.as_ref()),
) )
} }
fn get_include_stop_str_in_output(&self) -> Option<bool> {
self.common.include_stop_str_in_output
}
} }
impl OpenAIStopConditionsProvider for NvCreateCompletionRequest { impl OpenAIStopConditionsProvider for NvCreateCompletionRequest {
......
...@@ -25,6 +25,52 @@ fn test_chat_completions_ignore_eos_from_common() { ...@@ -25,6 +25,52 @@ fn test_chat_completions_ignore_eos_from_common() {
assert_eq!(request.common.ignore_eos, Some(true)); assert_eq!(request.common.ignore_eos, Some(true));
assert_eq!(request.common.min_tokens, Some(100)); assert_eq!(request.common.min_tokens, Some(100));
assert_eq!(request.common.include_stop_str_in_output, None);
}
#[test]
fn test_chat_completions_include_stop_str_in_output_from_common() {
let json_str = r#"{
"model": "test-model",
"messages": [{"role": "user", "content": "Hello"}],
"include_stop_str_in_output": true
}"#;
let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();
assert_eq!(request.common.include_stop_str_in_output, Some(true));
assert_eq!(request.get_include_stop_str_in_output(), Some(true));
}
#[test]
fn test_completions_include_stop_str_in_output_from_common() {
let json_str = r#"{
"model": "test-model",
"prompt": "Hello world",
"include_stop_str_in_output": true
}"#;
let request: NvCreateCompletionRequest = serde_json::from_str(json_str).unwrap();
assert_eq!(request.common.include_stop_str_in_output, Some(true));
// When exposed on completions, this should also be available via the provider
assert_eq!(request.get_include_stop_str_in_output(), Some(true));
}
#[test]
fn test_sampling_parameters_include_stop_str_in_output_extraction() {
use dynamo_llm::protocols::common::SamplingOptionsProvider;
let request = NvCreateChatCompletionRequest {
inner: Default::default(),
common: CommonExt::builder()
.include_stop_str_in_output(true)
.build()
.unwrap(),
nvext: None,
};
let sampling = request.extract_sampling_options().unwrap();
assert_eq!(sampling.include_stop_str_in_output, Some(true));
} }
#[test] #[test]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment