Unverified Commit 441473c3 authored by Ryan McCormick's avatar Ryan McCormick Committed by GitHub
Browse files

feat: Add support for skip_special_tokens parameter in v1/completions and...

feat: Add support for skip_special_tokens parameter in v1/completions and v1/chat/completions endpoints (#4175)
parent 14af074e
......@@ -94,12 +94,13 @@ impl Backend {
stream: ManyOut<ExecutionOutputStream>,
prompt_token_ids: &[TokenIdType],
stop_conditions: StopConditions,
skip_special_tokens: bool,
) -> anyhow::Result<DecoderUnfoldState> {
let Some(tokenizer) = self.tokenizer.as_ref() else {
anyhow::bail!("Backend built from blank ModelDeploymentCard, no tokenizer");
};
let decoder = Decoder::new(
tokenizer.decode_stream(prompt_token_ids, false),
tokenizer.decode_stream(prompt_token_ids, skip_special_tokens),
stop_conditions,
);
......@@ -129,10 +130,18 @@ impl
let prompt_token_ids = request.token_ids.clone();
// TODO: Consider updating default to true to match behavior of other frameworks
let skip_special_tokens = request.output_options.skip_special_tokens.unwrap_or(false);
let next_stream = next.generate(request).await?;
let context = next_stream.context();
let state = self.decoder(next_stream, &prompt_token_ids, stop_conditions)?;
let state = self.decoder(
next_stream,
&prompt_token_ids,
stop_conditions,
skip_special_tokens,
)?;
let processed_stream = stream::unfold(state, |mut state| async move {
match state.stream.next().await {
......
......@@ -473,8 +473,6 @@ pub struct OutputOptions {
pub prompt_logprobs: Option<u32>,
/// Whether to skip special tokens in the output.
/// spaces_between_special_tokens: Whether to add spaces between special
/// tokens in the output. Defaults to True.
pub skip_special_tokens: Option<bool>,
/// If true, the Context object will contain the prompt that was passed to
......
......@@ -198,6 +198,10 @@ impl CommonExtProvider for NvCreateChatCompletionRequest {
/// Returns the `include_stop_str_in_output` value stored in this request's
/// `common` extension fields (`None` when it was not set).
fn get_include_stop_str_in_output(&self) -> Option<bool> {
    let common = &self.common;
    common.include_stop_str_in_output
}
/// Returns the `skip_special_tokens` value stored in this request's
/// `common` extension fields (`None` when it was not set).
fn get_skip_special_tokens(&self) -> Option<bool> {
    let common = &self.common;
    common.skip_special_tokens
}
}
/// Implements `OpenAIStopConditionsProvider` for `NvCreateChatCompletionRequest`,
......@@ -263,7 +267,7 @@ impl OpenAIOutputOptionsProvider for NvCreateChatCompletionRequest {
}
/// Reports whether special tokens should be skipped in the decoded output.
///
/// Forwards the value carried in the request's common extension fields
/// instead of unconditionally answering `None`, so a client-supplied
/// `skip_special_tokens` reaches the extracted output options.
fn get_skip_special_tokens(&self) -> Option<bool> {
    // Fully-qualified call: both this trait and `CommonExtProvider` declare a
    // method with this name, so the trait must be named explicitly.
    CommonExtProvider::get_skip_special_tokens(self)
}
fn get_formatted_prompt(&self) -> Option<bool> {
......@@ -316,3 +320,53 @@ impl ValidateRequest for NvCreateChatCompletionRequest {
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocols::common::OutputOptionsProvider;
    use serde_json::json;

    /// Deserializes a chat completion request from a JSON value, panicking if
    /// the payload does not match the request schema.
    fn parse_request(body: serde_json::Value) -> NvCreateChatCompletionRequest {
        serde_json::from_value(body).expect("Failed to deserialize request")
    }

    /// Omitting `skip_special_tokens` leaves it `None` both on the request and
    /// in the extracted output options.
    #[test]
    fn test_skip_special_tokens_none() {
        let request = parse_request(json!({
            "model": "test-model",
            "messages": [
                {"role": "user", "content": "Hello"}
            ]
        }));
        assert_eq!(request.common.skip_special_tokens, None);

        let extracted = request
            .extract_output_options()
            .expect("Failed to extract output options");
        assert_eq!(extracted.skip_special_tokens, None);
    }

    /// An explicit `skip_special_tokens` value (either polarity) is carried
    /// through to the extracted output options.
    #[test]
    fn test_skip_special_tokens_propagates() {
        for flag in [true, false] {
            let request = parse_request(json!({
                "model": "test-model",
                "messages": [
                    {"role": "user", "content": "Hello"}
                ],
                "skip_special_tokens": flag
            }));

            let extracted = request
                .extract_output_options()
                .expect("Failed to extract output options");
            assert_eq!(extracted.skip_special_tokens, Some(flag));
        }
    }
}
......@@ -72,6 +72,14 @@ pub struct CommonExt {
#[builder(default, setter(strip_option))]
#[allow(unused)] // Not used
pub guided_whitespace_pattern: Option<String>,
/// Whether to skip special tokens in the decoded output.
/// When true, special tokens (like EOS, BOS, PAD) are removed from the output text.
/// When false, special tokens are included in the output text.
/// Defaults to false if not specified.
#[serde(default, skip_serializing_if = "Option::is_none")]
#[builder(default, setter(strip_option))]
pub skip_special_tokens: Option<bool>,
}
impl CommonExt {
......@@ -99,6 +107,9 @@ pub trait CommonExtProvider {
fn get_min_p(&self) -> Option<f32>;
fn get_repetition_penalty(&self) -> Option<f32>;
fn get_include_stop_str_in_output(&self) -> Option<bool>;
/// Output Options
fn get_skip_special_tokens(&self) -> Option<bool>;
}
#[cfg(test)]
......@@ -120,6 +131,7 @@ mod tests {
assert_eq!(common_ext.guided_choice, None);
assert_eq!(common_ext.guided_decoding_backend, None);
assert_eq!(common_ext.include_stop_str_in_output, None);
assert_eq!(common_ext.skip_special_tokens, None);
}
#[test]
......@@ -135,6 +147,7 @@ mod tests {
.guided_grammar("grammar".to_string())
.guided_choice(vec!["choice1".to_string(), "choice2".to_string()])
.guided_decoding_backend("backend".to_string())
.skip_special_tokens(false)
.build()
.unwrap();
......@@ -157,6 +170,7 @@ mod tests {
common_ext.guided_decoding_backend,
Some("backend".to_string())
);
assert_eq!(common_ext.skip_special_tokens, Some(false));
}
#[test]
......@@ -190,6 +204,7 @@ mod tests {
guided_choice: None,
guided_decoding_backend: None,
guided_whitespace_pattern: None,
skip_special_tokens: None,
};
assert!(common_ext.validate().is_ok());
}
......@@ -219,4 +234,52 @@ mod tests {
assert_eq!(common_ext.include_stop_str_in_output, None);
assert!(common_ext.validate().is_ok());
}
/// The builder setter stores exactly the value it was given.
#[test]
fn test_skip_special_tokens_field() {
    // Build a `CommonExt` whose only explicitly-set field is `skip_special_tokens`.
    let build_with = |flag: bool| {
        CommonExt::builder()
            .skip_special_tokens(flag)
            .build()
            .unwrap()
    };

    assert_eq!(build_with(true).skip_special_tokens, Some(true));
    assert_eq!(build_with(false).skip_special_tokens, Some(false));
}
/// `skip_special_tokens` survives a JSON round trip when set, and is left out
/// of the serialized form entirely when unset.
#[test]
fn test_skip_special_tokens_serialization() {
    // Serialize a value built with `flag`, parse it back, and report the field.
    let round_trip = |flag: bool| -> Option<bool> {
        let original = CommonExt::builder()
            .skip_special_tokens(flag)
            .build()
            .unwrap();
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: CommonExt = serde_json::from_str(&encoded).unwrap();
        decoded.skip_special_tokens
    };

    assert_eq!(round_trip(true), Some(true));
    assert_eq!(round_trip(false), Some(false));

    // With the field unset, `skip_serializing_if = "Option::is_none"` must
    // keep the key out of the output.
    let unset = CommonExt::builder().build().unwrap();
    let encoded = serde_json::to_string(&unset).unwrap();
    assert!(!encoded.contains("skip_special_tokens"));
}
}
......@@ -222,6 +222,10 @@ impl CommonExtProvider for NvCreateCompletionRequest {
/// Returns the `include_stop_str_in_output` value stored in this request's
/// `common` extension fields (`None` when it was not set).
fn get_include_stop_str_in_output(&self) -> Option<bool> {
    let common = &self.common;
    common.include_stop_str_in_output
}
/// Returns the `skip_special_tokens` value stored in this request's
/// `common` extension fields (`None` when it was not set).
fn get_skip_special_tokens(&self) -> Option<bool> {
    let common = &self.common;
    common.skip_special_tokens
}
}
impl OpenAIStopConditionsProvider for NvCreateCompletionRequest {
......@@ -397,7 +401,7 @@ impl OpenAIOutputOptionsProvider for NvCreateCompletionRequest {
}
/// Reports whether special tokens should be skipped in the decoded output.
///
/// Forwards the value carried in the request's common extension fields
/// instead of unconditionally answering `None`, so a client-supplied
/// `skip_special_tokens` reaches the extracted output options.
fn get_skip_special_tokens(&self) -> Option<bool> {
    // Fully-qualified call: both this trait and `CommonExtProvider` declare a
    // method with this name, so the trait must be named explicitly.
    CommonExtProvider::get_skip_special_tokens(self)
}
fn get_formatted_prompt(&self) -> Option<bool> {
......@@ -444,3 +448,49 @@ impl ValidateRequest for NvCreateCompletionRequest {
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocols::common::OutputOptionsProvider;
    use serde_json::json;

    /// Deserializes a completion request from a JSON value, panicking if the
    /// payload does not match the request schema.
    fn parse_request(body: serde_json::Value) -> NvCreateCompletionRequest {
        serde_json::from_value(body).expect("Failed to deserialize request")
    }

    /// Omitting `skip_special_tokens` leaves it `None` both on the request and
    /// in the extracted output options.
    #[test]
    fn test_skip_special_tokens_none() {
        let request = parse_request(json!({
            "model": "test-model",
            "prompt": "Hello, world!"
        }));
        assert_eq!(request.common.skip_special_tokens, None);

        let extracted = request
            .extract_output_options()
            .expect("Failed to extract output options");
        assert_eq!(extracted.skip_special_tokens, None);
    }

    /// An explicit `skip_special_tokens` value (either polarity) is carried
    /// through to the extracted output options.
    #[test]
    fn test_skip_special_tokens_propagates() {
        for flag in [true, false] {
            let request = parse_request(json!({
                "model": "test-model",
                "prompt": "Hello, world!",
                "skip_special_tokens": flag
            }));

            let extracted = request
                .extract_output_options()
                .expect("Failed to extract output options");
            assert_eq!(extracted.skip_special_tokens, Some(flag));
        }
    }
}
......@@ -178,3 +178,31 @@ fn test_long_sequence_incremental_decode_with_prefill() {
assert_eq!(output.trim(), output_text.to_string());
}
}
/// Decoding the same ids with and without `skip_special_tokens` yields the
/// text with and without the surrounding `<s>` / `</s>` markers.
#[test]
fn test_decode_with_skip_special_tokens() {
    let tokenizer = HuggingFaceTokenizer::from_file(TINYLLAMA_TOKENIZER_PATH)
        .expect("Failed to load remote HuggingFace tokenizer");

    // Wrap the encoded text in the special tokens by hand:
    // <s> (token_id: 1) + "Hello world" + </s> (token_id: 2)
    let encoded = tokenizer
        .encode("Hello world")
        .expect("Failed to encode text");
    let mut token_ids = Vec::new();
    token_ids.push(1); // <s>
    token_ids.extend(encoded.token_ids());
    token_ids.push(2); // </s>

    // Same ids, decoded under either setting of the flag.
    let render = |skip_special_tokens: bool| tokenizer.decode(&token_ids, skip_special_tokens);

    // Run both decodes before asserting, keeping special tokens first.
    let kept = render(false).expect("Failed to decode with skip_special_tokens=false");
    let stripped = render(true).expect("Failed to decode with skip_special_tokens=true");

    // Compare the complete decoded strings, not just substrings.
    assert_eq!(kept, "<s> Hello world</s>");
    assert_eq!(stripped, "Hello world");
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment