Unverified Commit 95042f82 authored by zhongdaor-nv's avatar zhongdaor-nv Committed by GitHub
Browse files

fix: Fix gpt_oss_parser last_content_delta to accumulate tokens + add regression test (#3301)


Signed-off-by: default avatarzhongdaor <zhongdaor@nvidia.com>
Signed-off-by: default avatarzhongdaor-nv <zhongdaor@nvidia.com>
Co-authored-by: default avatarGuanLuo <41310872+GuanLuo@users.noreply.github.com>
parent d354763c
...@@ -159,6 +159,9 @@ impl ReasoningParser for GptOssReasoningParser { ...@@ -159,6 +159,9 @@ impl ReasoningParser for GptOssReasoningParser {
); );
let parser: &mut StreamableParser = &mut self.parser; let parser: &mut StreamableParser = &mut self.parser;
let mut normal_delta = String::new();
let mut reasoning_delta = String::new();
for (i, token_id) in token_ids.iter().enumerate() { for (i, token_id) in token_ids.iter().enumerate() {
tracing::debug!( tracing::debug!(
"Processing streaming token {} of {}: {}", "Processing streaming token {} of {}: {}",
...@@ -170,26 +173,42 @@ impl ReasoningParser for GptOssReasoningParser { ...@@ -170,26 +173,42 @@ impl ReasoningParser for GptOssReasoningParser {
tracing::warn!("Harmony parse error for token_id {token_id}: {e}"); tracing::warn!("Harmony parse error for token_id {token_id}: {e}");
return ParserResult::default(); return ParserResult::default();
} }
}
if let Some(channel) = self.parser.current_channel() { if let (Some(delta), Some(channel)) = (
tracing::debug!("Current channel {}", channel); parser.last_content_delta().unwrap_or_default(),
if channel == "final" { parser.current_channel(),
// If we're in the final channel, we should not parse reasoning ) {
if let Some(current) = self.parser.last_content_delta().unwrap_or_default() { // `last_content_delta` only exposes the newest token slice, so we forward
tracing::debug!("Got normal text delta of {} chars", current.len()); // `final`/`analysis` chunks immediately; commentary is reconstructed in the
return ParserResult { // fallback path below because it needs the stripped metadata.
normal_text: current, match channel.as_str() {
reasoning_text: String::new(), "final" => normal_delta.push_str(&delta),
}; "analysis" => reasoning_delta.push_str(&delta),
"commentary" => {}
_ => {}
} }
tracing::debug!("No content delta in final channel"); }
ParserResult::default() }
} else if channel == "commentary" {
if !normal_delta.is_empty() || !reasoning_delta.is_empty() {
tracing::debug!(
"Returning aggregated deltas: normal: {} chars, reasoning: {} chars",
normal_delta.len(),
reasoning_delta.len()
);
return ParserResult {
normal_text: normal_delta,
reasoning_text: reasoning_delta,
};
}
if let Some(channel) = parser.current_channel() {
if channel == "commentary" {
tracing::debug!("In commentary channel, recovering full content");
// If we're in the commentary channel, we should return raw token content and recover content that has been consumed by the parser // If we're in the commentary channel, we should return raw token content and recover content that has been consumed by the parser
// so that the tool parser can process it properly // so that the tool parser can process it properly
if let Ok(enc) = get_harmony_encoding() { if let Ok(enc) = get_harmony_encoding() {
let current_content = self.parser.current_content().unwrap_or_default(); let current_content = parser.current_content().unwrap_or_default();
let mut final_text = text.to_string(); let mut final_text = text.to_string();
// Restore commentary metadata consumed by the parser so the tool-call parser can // Restore commentary metadata consumed by the parser so the tool-call parser can
...@@ -206,7 +225,7 @@ impl ReasoningParser for GptOssReasoningParser { ...@@ -206,7 +225,7 @@ impl ReasoningParser for GptOssReasoningParser {
// Recovery should only happen once, and only when `current_content` is empty. // Recovery should only happen once, and only when `current_content` is empty.
if current_content.is_empty() { if current_content.is_empty() {
let tokens = self.parser.tokens(); let tokens = parser.tokens();
// Get the token id for " <|channel|>" // Get the token id for " <|channel|>"
let channel_token_id = enc let channel_token_id = enc
...@@ -222,43 +241,29 @@ impl ReasoningParser for GptOssReasoningParser { ...@@ -222,43 +241,29 @@ impl ReasoningParser for GptOssReasoningParser {
}) })
.unwrap_or(0); .unwrap_or(0);
// Then get the generated text from the last <|channel|> to the end of self.parser.tokens() // Then get the generated text from the last <|channel|> to the end of parser.tokens()
let end_token_idx = self.parser.tokens().len(); let end_token_idx = parser.tokens().len();
// Use Harmony's decode_utf8 to decode tokens into text // Use Harmony's decode_utf8 to decode tokens into text
let generated_text = enc let generated_text = enc
.tokenizer() .tokenizer()
.decode_utf8( .decode_utf8(&parser.tokens()[last_channel_token_idx..end_token_idx])
&self.parser.tokens()[last_channel_token_idx..end_token_idx],
)
.unwrap_or_default(); .unwrap_or_default();
final_text = generated_text; final_text = generated_text;
} }
ParserResult { return ParserResult {
normal_text: final_text, normal_text: final_text,
reasoning_text: String::new(), reasoning_text: String::new(),
}
} else {
tracing::warn!("Failed to get harmony encoding for raw token decoding");
ParserResult::default()
}
} else {
tracing::debug!("In reasoning channel: {}", channel);
if let Some(current) = self.parser.last_content_delta().unwrap_or_default() {
tracing::debug!("Got reasoning text delta of {} chars", current.len());
return ParserResult {
normal_text: String::new(),
reasoning_text: current,
}; };
} }
tracing::debug!("No content delta in reasoning channel"); } else {
ParserResult::default() tracing::warn!("Shouldn't be delta content after in channel: {}", channel);
} }
} else {
tracing::debug!("No current channel detected");
ParserResult::default()
} }
tracing::debug!("No deltas to return, returning empty result");
ParserResult::default()
} }
} }
...@@ -303,4 +308,98 @@ mod tests { ...@@ -303,4 +308,98 @@ mod tests {
== "The user asks a simple factual question: capital of Brazil. The answer is Brasília. No additional explanation needed." == "The user asks a simple factual question: capital of Brazil. The answer is Brasília. No additional explanation needed."
); );
} }
#[test]
fn test_gpt_oss_reasoning_parser_streaming_chunked() {
let mut parser = GptOssReasoningParser::new().expect("Failed to create parser");
let enc = get_harmony_encoding()
.as_ref()
.expect("Failed to get encoding");
let text = "<|channel|>analysis<|message|>The user asks a simple factual question: capital of Brazil. The answer is Brasília. No additional explanation needed.<|end|><|start|>assistant<|channel|>final<|message|>The capital of Brazil is Brasília.";
let token_ids = enc.tokenizer().encode_with_special_tokens(text);
let mut reasoning_text_incr = String::new();
let mut normal_text_incr = String::new();
let mut idx = 0;
let chunk_size = 4;
while idx < token_ids.len() {
let end = (idx + chunk_size).min(token_ids.len());
let result =
parser.parse_reasoning_streaming_incremental("Test text", &token_ids[idx..end]);
normal_text_incr.push_str(&result.normal_text);
reasoning_text_incr.push_str(&result.reasoning_text);
idx = end;
}
assert_eq!(normal_text_incr, "The capital of Brazil is Brasília.");
assert_eq!(
reasoning_text_incr,
"The user asks a simple factual question: capital of Brazil. The answer is Brasília. No additional explanation needed."
);
}
}
#[test]
fn test_gpt_oss_reasoning_parser_streaming_variable_length_chunks() {
let text = "<|channel|>analysis<|message|>User asks: \"Hey, quick check: is everything up and running?\" We should check system health using the provided function get_system_health. Use function.<|end|><|start|>assistant<|channel|>commentary to=functions.get_system_health <|constrain|>json<|message|>{}";
let enc = get_harmony_encoding()
.as_ref()
.expect("Failed to get encoding");
let token_ids = enc.tokenizer().encode_with_special_tokens(text); // Example token IDs
// Send token one by one
{
let mut parser = GptOssReasoningParser::new().expect("Failed to create parser");
let mut reasoning_text_incr = String::new();
let mut normal_text_incr = String::new();
for token in token_ids.iter() {
let result = parser.parse_reasoning_streaming_incremental("", &[(*token)]);
normal_text_incr.push_str(&result.normal_text);
reasoning_text_incr.push_str(&result.reasoning_text);
}
assert_eq!(
reasoning_text_incr,
"User asks: \"Hey, quick check: is everything up and running?\" We should check system health using the provided function get_system_health. Use function."
);
// [gluo TODO] missing "<|start|>assistant" and "{}" from original message
assert_eq!(
normal_text_incr,
"<|channel|>commentary to=functions.get_system_health <|constrain|>json<|message|>"
);
}
// Send token in chunks (chunking obtained from actual model output)
{
let mut parser = GptOssReasoningParser::new().expect("Failed to create parser");
let mut reasoning_text_incr = String::new();
let mut normal_text_incr = String::new();
let chunk_tokens = [
vec![200005],
vec![35644, 200008, 1844, 31064, 25, 392, 25216, 11, 4853],
vec![2371, 25, 382, 5519, 869, 326, 6788, 16842, 1416, 1757],
vec![2371, 2420, 3230, 2360, 290, 5181, 1114, 717, 39303, 126214],
vec![
13, 7649, 1114, 13, 200007, 200006, 173781, 200005, 12606, 815,
],
vec![
316, 28, 44580, 775, 39303, 126214, 220, 200003, 4108, 200008,
],
vec![12083],
];
// concatenate chunk tokens and verify they match original token_ids
let concatenated: Vec<u32> = chunk_tokens.iter().flatten().copied().collect();
assert_eq!(concatenated, token_ids);
for token in chunk_tokens.iter() {
let result = parser.parse_reasoning_streaming_incremental("", token);
normal_text_incr.push_str(&result.normal_text);
reasoning_text_incr.push_str(&result.reasoning_text);
}
assert_eq!(
reasoning_text_incr,
"User asks: \"Hey, quick check: is everything up and running?\" We should check system health using the provided function get_system_health. Use function."
);
assert_eq!(
normal_text_incr,
"<|channel|>commentary to=functions.get_system_health <|constrain|>json<|message|>"
);
}
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment