Unverified Commit 3d2d7e47 authored by Ayush Agarwal's avatar Ayush Agarwal Committed by GitHub
Browse files

chore: enable gpt oss reasoning to work with text and tokens (#3321)


Signed-off-by: default avatarAyush Agarwal <ayushag@nvidia.com>
parent f4a3a6b6
......@@ -56,12 +56,28 @@ impl GptOssReasoningParser {
}
}
fn encode_text_to_tokens(text: &str) -> anyhow::Result<Vec<u32>> {
let enc = get_harmony_encoding()
.as_ref()
.map_err(|e| anyhow::anyhow!("Failed to get harmony encoding: {e}"))?;
Ok(enc.tokenizer().encode_with_special_tokens(text))
}
impl ReasoningParser for GptOssReasoningParser {
fn detect_and_parse_reasoning(&mut self, _text: &str, token_ids: &[u32]) -> ParserResult {
tracing::debug!(
"detect_and_parse_reasoning called with {} token_ids",
token_ids.len()
);
fn detect_and_parse_reasoning(&mut self, text: &str, token_ids: &[u32]) -> ParserResult {
let token_ids = if token_ids.is_empty() {
// WAR: Since we are moving to just text based reasoning parsing, converting to token_ids now using harmony encoding
let encoded_tokens = match encode_text_to_tokens(text) {
Ok(tokens) => tokens,
Err(err) => {
tracing::warn!("Failed to encode Harmony tokens: {err}");
return ParserResult::default();
}
};
&encoded_tokens.to_vec()
} else {
token_ids
};
let parser = &mut self.parser;
......@@ -153,10 +169,19 @@ impl ReasoningParser for GptOssReasoningParser {
text: &str,
token_ids: &[u32],
) -> ParserResult {
tracing::debug!(
"parse_reasoning_streaming_incremental called with {} token_ids",
token_ids.len()
);
let token_ids = if token_ids.is_empty() {
// WAR: Since we are moving to just text based reasoning parsing, converting to token_ids now using harmony encoding
let encoded_tokens = match encode_text_to_tokens(text) {
Ok(tokens) => tokens,
Err(err) => {
tracing::warn!("Failed to encode Harmony tokens: {err}");
return ParserResult::default();
}
};
&encoded_tokens.to_vec()
} else {
token_ids
};
let parser: &mut StreamableParser = &mut self.parser;
let mut normal_delta = String::new();
......@@ -261,7 +286,6 @@ impl ReasoningParser for GptOssReasoningParser {
tracing::warn!("Shouldn't be delta content after in channel: {}", channel);
}
}
tracing::debug!("No deltas to return, returning empty result");
ParserResult::default()
}
......@@ -274,12 +298,8 @@ mod tests {
#[test]
fn test_gpt_oss_reasoning_parser() {
let mut parser = GptOssReasoningParser::new().expect("Failed to create parser");
let enc = get_harmony_encoding()
.as_ref()
.expect("Failed to get encoding");
let text = "<|channel|>analysis<|message|>The user asks a simple factual question: capital of Brazil. The answer is Brasília. No additional explanation needed.<|end|><|start|>assistant<|channel|>final<|message|>The capital of Brazil is Brasília.";
let token_ids = enc.tokenizer().encode_with_special_tokens(text); // Example token IDs
let result = parser.detect_and_parse_reasoning("Test text", &token_ids);
let result = parser.detect_and_parse_reasoning(text, &[]);
assert!(result.normal_text == "The capital of Brazil is Brasília.");
assert!(
result.reasoning_text
......@@ -290,15 +310,17 @@ mod tests {
#[test]
fn test_gpt_oss_reasoning_parser_streaming() {
let mut parser = GptOssReasoningParser::new().expect("Failed to create parser");
let enc = get_harmony_encoding()
.as_ref()
.expect("Failed to get encoding");
let text = "<|channel|>analysis<|message|>The user asks a simple factual question: capital of Brazil. The answer is Brasília. No additional explanation needed.<|end|><|start|>assistant<|channel|>final<|message|>The capital of Brazil is Brasília.";
let token_ids = enc.tokenizer().encode_with_special_tokens(text); // Example token IDs
let chunks = vec![
"<|channel|>",
"analysis<|message|>The user asks a simple factual question: capital of Brazil.",
" The answer is Brasília. No additional explanation needed.",
"<|end|><|start|>assistant<|channel|>final<|message|>",
"The capital of Brazil is Brasília.",
];
let mut reasoning_text_incr = String::new();
let mut normal_text_incr = String::new();
for token in token_ids.iter() {
let result = parser.parse_reasoning_streaming_incremental("Test text", &[(*token)]);
for chunk in chunks {
let result = parser.parse_reasoning_streaming_incremental(chunk, &[]);
normal_text_incr.push_str(&result.normal_text);
reasoning_text_incr.push_str(&result.reasoning_text);
}
......@@ -337,15 +359,15 @@ mod tests {
"The user asks a simple factual question: capital of Brazil. The answer is Brasília. No additional explanation needed."
);
}
}
#[test]
fn test_gpt_oss_reasoning_parser_streaming_variable_length_chunks() {
#[test]
fn test_gpt_oss_reasoning_parser_streaming_variable_length_chunks() {
let text = "<|channel|>analysis<|message|>User asks: \"Hey, quick check: is everything up and running?\" We should check system health using the provided function get_system_health. Use function.<|end|><|start|>assistant<|channel|>commentary to=functions.get_system_health <|constrain|>json<|message|>{}";
let enc = get_harmony_encoding()
.as_ref()
.expect("Failed to get encoding");
let token_ids = enc.tokenizer().encode_with_special_tokens(text); // Example token IDs
let token_ids = enc.tokenizer().encode_with_special_tokens(text);
// Send token one by one
{
let mut parser = GptOssReasoningParser::new().expect("Failed to create parser");
......@@ -366,6 +388,7 @@ fn test_gpt_oss_reasoning_parser_streaming_variable_length_chunks() {
"<|channel|>commentary to=functions.get_system_health <|constrain|>json<|message|>"
);
}
// Send token in chunks (chunking obtained from actual model output)
{
let mut parser = GptOssReasoningParser::new().expect("Failed to create parser");
......@@ -384,7 +407,7 @@ fn test_gpt_oss_reasoning_parser_streaming_variable_length_chunks() {
],
vec![12083],
];
// concatenate chunk tokens and verify they match original token_ids
// Concatenate chunk tokens and verify they match original token_ids
let concatenated: Vec<u32> = chunk_tokens.iter().flatten().copied().collect();
assert_eq!(concatenated, token_ids);
......@@ -402,4 +425,5 @@ fn test_gpt_oss_reasoning_parser_streaming_variable_length_chunks() {
"<|channel|>commentary to=functions.get_system_health <|constrain|>json<|message|>"
);
}
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment