Unverified Commit 3d2d7e47 authored by Ayush Agarwal's avatar Ayush Agarwal Committed by GitHub
Browse files

chore: enable gpt oss reasoning to work with text and tokens (#3321)


Signed-off-by: Ayush Agarwal <ayushag@nvidia.com>
parent f4a3a6b6
...@@ -56,12 +56,28 @@ impl GptOssReasoningParser { ...@@ -56,12 +56,28 @@ impl GptOssReasoningParser {
} }
} }
/// Tokenizes `text` with the tokenizer of the harmony encoding.
///
/// The parser methods use this as a fallback when they are handed text but an
/// empty `token_ids` slice.
///
/// # Errors
/// Returns an error when `get_harmony_encoding()` does not yield a usable
/// encoding; the underlying error is included in the message.
fn encode_text_to_tokens(text: &str) -> anyhow::Result<Vec<u32>> {
    match get_harmony_encoding().as_ref() {
        // Encoding is available: tokenize via `encode_with_special_tokens`.
        Ok(encoding) => Ok(encoding.tokenizer().encode_with_special_tokens(text)),
        // Encoding failed to load: surface the cause to the caller.
        Err(e) => Err(anyhow::anyhow!("Failed to get harmony encoding: {e}")),
    }
}
impl ReasoningParser for GptOssReasoningParser { impl ReasoningParser for GptOssReasoningParser {
fn detect_and_parse_reasoning(&mut self, _text: &str, token_ids: &[u32]) -> ParserResult { fn detect_and_parse_reasoning(&mut self, text: &str, token_ids: &[u32]) -> ParserResult {
tracing::debug!( let token_ids = if token_ids.is_empty() {
"detect_and_parse_reasoning called with {} token_ids", // WAR: Since we are moving to just text based reasoning parsing, converting to token_ids now using harmony encoding
token_ids.len() let encoded_tokens = match encode_text_to_tokens(text) {
); Ok(tokens) => tokens,
Err(err) => {
tracing::warn!("Failed to encode Harmony tokens: {err}");
return ParserResult::default();
}
};
&encoded_tokens.to_vec()
} else {
token_ids
};
let parser = &mut self.parser; let parser = &mut self.parser;
...@@ -153,10 +169,19 @@ impl ReasoningParser for GptOssReasoningParser { ...@@ -153,10 +169,19 @@ impl ReasoningParser for GptOssReasoningParser {
text: &str, text: &str,
token_ids: &[u32], token_ids: &[u32],
) -> ParserResult { ) -> ParserResult {
tracing::debug!( let token_ids = if token_ids.is_empty() {
"parse_reasoning_streaming_incremental called with {} token_ids", // WAR: Since we are moving to just text based reasoning parsing, converting to token_ids now using harmony encoding
token_ids.len() let encoded_tokens = match encode_text_to_tokens(text) {
); Ok(tokens) => tokens,
Err(err) => {
tracing::warn!("Failed to encode Harmony tokens: {err}");
return ParserResult::default();
}
};
&encoded_tokens.to_vec()
} else {
token_ids
};
let parser: &mut StreamableParser = &mut self.parser; let parser: &mut StreamableParser = &mut self.parser;
let mut normal_delta = String::new(); let mut normal_delta = String::new();
...@@ -261,7 +286,6 @@ impl ReasoningParser for GptOssReasoningParser { ...@@ -261,7 +286,6 @@ impl ReasoningParser for GptOssReasoningParser {
tracing::warn!("Shouldn't be delta content after in channel: {}", channel); tracing::warn!("Shouldn't be delta content after in channel: {}", channel);
} }
} }
tracing::debug!("No deltas to return, returning empty result"); tracing::debug!("No deltas to return, returning empty result");
ParserResult::default() ParserResult::default()
} }
...@@ -274,12 +298,8 @@ mod tests { ...@@ -274,12 +298,8 @@ mod tests {
#[test] #[test]
fn test_gpt_oss_reasoning_parser() { fn test_gpt_oss_reasoning_parser() {
let mut parser = GptOssReasoningParser::new().expect("Failed to create parser"); let mut parser = GptOssReasoningParser::new().expect("Failed to create parser");
let enc = get_harmony_encoding()
.as_ref()
.expect("Failed to get encoding");
let text = "<|channel|>analysis<|message|>The user asks a simple factual question: capital of Brazil. The answer is Brasília. No additional explanation needed.<|end|><|start|>assistant<|channel|>final<|message|>The capital of Brazil is Brasília."; let text = "<|channel|>analysis<|message|>The user asks a simple factual question: capital of Brazil. The answer is Brasília. No additional explanation needed.<|end|><|start|>assistant<|channel|>final<|message|>The capital of Brazil is Brasília.";
let token_ids = enc.tokenizer().encode_with_special_tokens(text); // Example token IDs let result = parser.detect_and_parse_reasoning(text, &[]);
let result = parser.detect_and_parse_reasoning("Test text", &token_ids);
assert!(result.normal_text == "The capital of Brazil is Brasília."); assert!(result.normal_text == "The capital of Brazil is Brasília.");
assert!( assert!(
result.reasoning_text result.reasoning_text
...@@ -290,15 +310,17 @@ mod tests { ...@@ -290,15 +310,17 @@ mod tests {
#[test] #[test]
fn test_gpt_oss_reasoning_parser_streaming() { fn test_gpt_oss_reasoning_parser_streaming() {
let mut parser = GptOssReasoningParser::new().expect("Failed to create parser"); let mut parser = GptOssReasoningParser::new().expect("Failed to create parser");
let enc = get_harmony_encoding() let chunks = vec![
.as_ref() "<|channel|>",
.expect("Failed to get encoding"); "analysis<|message|>The user asks a simple factual question: capital of Brazil.",
let text = "<|channel|>analysis<|message|>The user asks a simple factual question: capital of Brazil. The answer is Brasília. No additional explanation needed.<|end|><|start|>assistant<|channel|>final<|message|>The capital of Brazil is Brasília."; " The answer is Brasília. No additional explanation needed.",
let token_ids = enc.tokenizer().encode_with_special_tokens(text); // Example token IDs "<|end|><|start|>assistant<|channel|>final<|message|>",
"The capital of Brazil is Brasília.",
];
let mut reasoning_text_incr = String::new(); let mut reasoning_text_incr = String::new();
let mut normal_text_incr = String::new(); let mut normal_text_incr = String::new();
for token in token_ids.iter() { for chunk in chunks {
let result = parser.parse_reasoning_streaming_incremental("Test text", &[(*token)]); let result = parser.parse_reasoning_streaming_incremental(chunk, &[]);
normal_text_incr.push_str(&result.normal_text); normal_text_incr.push_str(&result.normal_text);
reasoning_text_incr.push_str(&result.reasoning_text); reasoning_text_incr.push_str(&result.reasoning_text);
} }
...@@ -337,15 +359,15 @@ mod tests { ...@@ -337,15 +359,15 @@ mod tests {
"The user asks a simple factual question: capital of Brazil. The answer is Brasília. No additional explanation needed." "The user asks a simple factual question: capital of Brazil. The answer is Brasília. No additional explanation needed."
); );
} }
}
#[test] #[test]
fn test_gpt_oss_reasoning_parser_streaming_variable_length_chunks() { fn test_gpt_oss_reasoning_parser_streaming_variable_length_chunks() {
let text = "<|channel|>analysis<|message|>User asks: \"Hey, quick check: is everything up and running?\" We should check system health using the provided function get_system_health. Use function.<|end|><|start|>assistant<|channel|>commentary to=functions.get_system_health <|constrain|>json<|message|>{}"; let text = "<|channel|>analysis<|message|>User asks: \"Hey, quick check: is everything up and running?\" We should check system health using the provided function get_system_health. Use function.<|end|><|start|>assistant<|channel|>commentary to=functions.get_system_health <|constrain|>json<|message|>{}";
let enc = get_harmony_encoding() let enc = get_harmony_encoding()
.as_ref() .as_ref()
.expect("Failed to get encoding"); .expect("Failed to get encoding");
let token_ids = enc.tokenizer().encode_with_special_tokens(text); // Example token IDs let token_ids = enc.tokenizer().encode_with_special_tokens(text);
// Send token one by one // Send token one by one
{ {
let mut parser = GptOssReasoningParser::new().expect("Failed to create parser"); let mut parser = GptOssReasoningParser::new().expect("Failed to create parser");
...@@ -366,6 +388,7 @@ fn test_gpt_oss_reasoning_parser_streaming_variable_length_chunks() { ...@@ -366,6 +388,7 @@ fn test_gpt_oss_reasoning_parser_streaming_variable_length_chunks() {
"<|channel|>commentary to=functions.get_system_health <|constrain|>json<|message|>" "<|channel|>commentary to=functions.get_system_health <|constrain|>json<|message|>"
); );
} }
// Send token in chunks (chunking obtained from actual model output) // Send token in chunks (chunking obtained from actual model output)
{ {
let mut parser = GptOssReasoningParser::new().expect("Failed to create parser"); let mut parser = GptOssReasoningParser::new().expect("Failed to create parser");
...@@ -384,7 +407,7 @@ fn test_gpt_oss_reasoning_parser_streaming_variable_length_chunks() { ...@@ -384,7 +407,7 @@ fn test_gpt_oss_reasoning_parser_streaming_variable_length_chunks() {
], ],
vec![12083], vec![12083],
]; ];
// concatenate chunk tokens and verify they match original token_ids // Concatenate chunk tokens and verify they match original token_ids
let concatenated: Vec<u32> = chunk_tokens.iter().flatten().copied().collect(); let concatenated: Vec<u32> = chunk_tokens.iter().flatten().copied().collect();
assert_eq!(concatenated, token_ids); assert_eq!(concatenated, token_ids);
...@@ -402,4 +425,5 @@ fn test_gpt_oss_reasoning_parser_streaming_variable_length_chunks() { ...@@ -402,4 +425,5 @@ fn test_gpt_oss_reasoning_parser_streaming_variable_length_chunks() {
"<|channel|>commentary to=functions.get_system_health <|constrain|>json<|message|>" "<|channel|>commentary to=functions.get_system_health <|constrain|>json<|message|>"
); );
} }
}
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment