Unverified Commit 3036e60b authored by nachiketb-nvidia's avatar nachiketb-nvidia Committed by GitHub
Browse files

feat: add gpt oss reasoning parser through harmony (#2656)

- couple of refactors
- added a new dependency, openai-harmony
- implemented the gpt oss parser
parent b39382ba
......@@ -172,6 +172,17 @@ version = "1.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "69f7f8c3906b62b754cd5326047894316021dcfe5a194c8ea52bdd94934a3457"
[[package]]
name = "arg_enum_proc_macro"
version = "0.3.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0ae92a5119aa49cdbcf6b9f893fe4e1d98b04ccbf82ee0584ad948a44a734dea"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.104",
]
[[package]]
name = "arrayref"
version = "0.3.9"
......@@ -341,6 +352,29 @@ version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8"
[[package]]
name = "av1-grain"
version = "0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4f3efb2ca85bc610acfa917b5aaa36f3fcbebed5b3182d7f877b02531c4b80c8"
dependencies = [
"anyhow",
"arrayvec",
"log",
"nom 7.1.3",
"num-rational",
"v_frame",
]
[[package]]
name = "avif-serialize"
version = "0.8.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "47c8fbc0f831f4519fe8b810b6a7a91410ec83031b8233f730a0480029f6a23f"
dependencies = [
"arrayvec",
]
[[package]]
name = "aws-lc-rs"
version = "1.13.3"
......@@ -693,6 +727,12 @@ version = "2.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1b8e56985ec62d17e9c1001dc89c88ecd7dc08e47eba5ec7c29c7b5eeecde967"
[[package]]
name = "bitstream-io"
version = "2.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6099cdc01846bc367c4e7dd630dc5966dccf36b652fae7a74e17b640411a91b2"
[[package]]
name = "blake3"
version = "1.8.2"
......@@ -753,9 +793,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "234113d19d0d7d613b40e86fb654acf958910802bcceab913a4f9e7cda03b1a4"
dependencies = [
"memchr",
"regex-automata 0.4.9",
"serde",
]
[[package]]
name = "built"
version = "0.7.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "56ed6191a7e78c36abdb16ab65341eefd73d64d303fffccdbb00d51e4205967b"
[[package]]
name = "bumpalo"
version = "3.19.0"
......@@ -1991,6 +2038,8 @@ version = "0.4.1"
dependencies = [
"anyhow",
"dynamo-async-openai",
"lazy_static",
"openai-harmony",
"regex",
"serde",
"serde_json",
......@@ -2353,6 +2402,17 @@ dependencies = [
"regex",
]
[[package]]
name = "fancy-regex"
version = "0.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "531e46835a22af56d1e3b66f04844bed63158bc094a628bec1d321d9b4c44bf2"
dependencies = [
"bit-set 0.5.3",
"regex-automata 0.4.9",
"regex-syntax 0.8.5",
]
[[package]]
name = "fancy-regex"
version = "0.14.0"
......@@ -3130,6 +3190,12 @@ version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fc0fef456e4baa96da950455cd02c081ca953b141298e41db3fc7e36b1da849c"
[[package]]
name = "hex"
version = "0.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70"
[[package]]
name = "hf-hub"
version = "0.4.3"
......@@ -3548,6 +3614,9 @@ dependencies = [
"num-traits",
"png",
"qoi",
"ravif",
"rayon",
"rgb",
"tiff",
"zune-core",
"zune-jpeg",
......@@ -3563,6 +3632,12 @@ dependencies = [
"quick-error 2.0.1",
]
[[package]]
name = "imgref"
version = "1.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d0263a3d970d5c054ed9312c0057b4f3bde9c0b33836d3637361d4a9e6e7a408"
[[package]]
name = "indexmap"
version = "1.9.3"
......@@ -3571,6 +3646,7 @@ checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99"
dependencies = [
"autocfg",
"hashbrown 0.12.3",
"serde",
]
[[package]]
......@@ -3630,6 +3706,17 @@ dependencies = [
"cfg-if 1.0.1",
]
[[package]]
name = "interpolate_name"
version = "0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c34819042dc3d3971c46c2190835914dfbe0c3c13f61449b2997f4e9722dfa60"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.104",
]
[[package]]
name = "interprocess"
version = "2.2.3"
......@@ -3703,6 +3790,15 @@ dependencies = [
"either",
]
[[package]]
name = "itertools"
version = "0.12.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569"
dependencies = [
"either",
]
[[package]]
name = "itertools"
version = "0.14.0"
......@@ -3863,6 +3959,16 @@ dependencies = [
"uuid 1.17.0",
]
[[package]]
name = "libfuzzer-sys"
version = "0.4.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5037190e1f70cbeef565bd267599242926f724d3b8a9f510fd7e0b540cfa4404"
dependencies = [
"arbitrary",
"cc",
]
[[package]]
name = "libloading"
version = "0.8.8"
......@@ -3976,6 +4082,15 @@ version = "0.4.27"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "13dc2df351e3202783a1fe0d44375f7295ffb4049267b0f3018346dc122a1d94"
[[package]]
name = "loop9"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0fae87c125b03c1d2c0150c90365d7d6bcc53fb73a9acaef207d2d065860f062"
dependencies = [
"imgref",
]
[[package]]
name = "lrtable"
version = "0.13.10"
......@@ -4093,6 +4208,16 @@ dependencies = [
"rawpointer",
]
[[package]]
name = "maybe-rayon"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8ea1f30cedd69f0a2954655f7188c6a834246d2bcf1e315e2ac40c4b24dc9519"
dependencies = [
"cfg-if 1.0.1",
"rayon",
]
[[package]]
name = "memchr"
version = "2.7.5"
......@@ -4383,7 +4508,7 @@ dependencies = [
"rustc-hash 2.1.1",
"rustfft",
"safetensors 0.6.1",
"schemars",
"schemars 0.8.22",
"scraper",
"serde",
"serde-big-array",
......@@ -4702,6 +4827,12 @@ version = "0.5.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "610a5acd306ec67f907abe5567859a3c693fb9886eb1f012ab8f2a47bef3db51"
[[package]]
name = "noop_proc_macro"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0676bb32a98c1a483ce53e500a81ad9c3d5b3f7c920c28c24e9cb0980d0b5bc8"
[[package]]
name = "ntapi"
version = "0.4.1"
......@@ -4776,6 +4907,17 @@ version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9"
[[package]]
name = "num-derive"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ed3955f1a9c7c0c15e092f9c887db08b1fc683305fdf6eb6684f22555355e202"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.104",
]
[[package]]
name = "num-integer"
version = "0.1.46"
......@@ -4948,6 +5090,30 @@ version = "11.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e"
[[package]]
name = "openai-harmony"
version = "0.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b6b7fd6b7d01a317c58d85a22b7cc314e14a2fef0dfdb93b819738d09caece16"
dependencies = [
"anyhow",
"base64 0.22.1",
"bstr",
"clap 4.5.42",
"fancy-regex 0.13.0",
"futures",
"image",
"regex",
"reqwest 0.12.22",
"rustc-hash 1.1.0",
"serde",
"serde_json",
"serde_with",
"sha1",
"sha2",
"thiserror 2.0.12",
]
[[package]]
name = "openssl"
version = "0.10.73"
......@@ -5437,6 +5603,25 @@ dependencies = [
"yansi",
]
[[package]]
name = "profiling"
version = "1.0.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3eb8486b569e12e2c32ad3e204dbaba5e4b5b216e9367044f25f1dba42341773"
dependencies = [
"profiling-procmacros",
]
[[package]]
name = "profiling-procmacros"
version = "1.0.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "52717f9a02b6965224f95ca2a81e2e0c5c43baacd28ca057577988930b6c3d5b"
dependencies = [
"quote",
"syn 2.0.104",
]
[[package]]
name = "prometheus"
version = "0.14.0"
......@@ -5777,6 +5962,56 @@ dependencies = [
"rand_core 0.9.3",
]
[[package]]
name = "rav1e"
version = "0.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cd87ce80a7665b1cce111f8a16c1f3929f6547ce91ade6addf4ec86a8dda5ce9"
dependencies = [
"arbitrary",
"arg_enum_proc_macro",
"arrayvec",
"av1-grain",
"bitstream-io",
"built",
"cfg-if 1.0.1",
"interpolate_name",
"itertools 0.12.1",
"libc",
"libfuzzer-sys",
"log",
"maybe-rayon",
"new_debug_unreachable",
"noop_proc_macro",
"num-derive",
"num-traits",
"once_cell",
"paste",
"profiling",
"rand 0.8.5",
"rand_chacha 0.3.1",
"simd_helpers",
"system-deps",
"thiserror 1.0.69",
"v_frame",
"wasm-bindgen",
]
[[package]]
name = "ravif"
version = "0.11.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5825c26fddd16ab9f515930d49028a630efec172e903483c94796cfe31893e6b"
dependencies = [
"avif-serialize",
"imgref",
"loop9",
"quick-error 2.0.1",
"rav1e",
"rayon",
"rgb",
]
[[package]]
name = "raw-cpuid"
version = "10.7.0"
......@@ -5873,6 +6108,26 @@ dependencies = [
"thiserror 2.0.12",
]
[[package]]
name = "ref-cast"
version = "1.0.24"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4a0ae411dbe946a674d89546582cea4ba2bb8defac896622d6496f14c23ba5cf"
dependencies = [
"ref-cast-impl",
]
[[package]]
name = "ref-cast-impl"
version = "1.0.24"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1165225c21bff1f3bbce98f5a1f889949bc902d3575308cc7b0de30b4f6d27c7"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.104",
]
[[package]]
name = "regex"
version = "1.11.1"
......@@ -6025,6 +6280,12 @@ dependencies = [
"thiserror 1.0.69",
]
[[package]]
name = "rgb"
version = "0.8.52"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0c6a884d2998352bb4daf0183589aec883f16a6da1f4dde84d8e2e9a5409a1ce"
[[package]]
name = "ring"
version = "0.17.14"
......@@ -6418,6 +6679,30 @@ dependencies = [
"serde_json",
]
[[package]]
name = "schemars"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4cd191f9397d57d581cddd31014772520aa448f65ef991055d7f61582c65165f"
dependencies = [
"dyn-clone",
"ref-cast",
"serde",
"serde_json",
]
[[package]]
name = "schemars"
version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "82d20c4491bc164fa2f6c5d44565947a52ad80b9505d8e36f8d54c27c739fcd0"
dependencies = [
"dyn-clone",
"ref-cast",
"serde",
"serde_json",
]
[[package]]
name = "schemars_derive"
version = "0.8.22"
......@@ -6663,6 +6948,38 @@ dependencies = [
"serde",
]
[[package]]
name = "serde_with"
version = "3.14.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f2c45cd61fefa9db6f254525d46e392b852e0e61d9a1fd36e5bd183450a556d5"
dependencies = [
"base64 0.22.1",
"chrono",
"hex",
"indexmap 1.9.3",
"indexmap 2.10.0",
"schemars 0.9.0",
"schemars 1.0.4",
"serde",
"serde_derive",
"serde_json",
"serde_with_macros",
"time",
]
[[package]]
name = "serde_with_macros"
version = "3.14.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "de90945e6565ce0d9a25098082ed4ee4002e047cb59892c318d66821e14bb30f"
dependencies = [
"darling 0.20.11",
"proc-macro2",
"quote",
"syn 2.0.104",
]
[[package]]
name = "serde_yaml"
version = "0.9.34+deprecated"
......@@ -6824,6 +7141,15 @@ version = "0.3.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d66dc143e6b11c1eddc06d5c423cfc97062865baf299914ab64caa38182078fe"
[[package]]
name = "simd_helpers"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "95890f873bec569a0362c235787f3aca6e1e887302ba4840839bcc6459c42da6"
dependencies = [
"quote",
]
[[package]]
name = "similar"
version = "2.7.0"
......@@ -8313,6 +8639,17 @@ dependencies = [
"wasm-bindgen",
]
[[package]]
name = "v_frame"
version = "0.3.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "666b7727c8875d6ab5db9533418d7c764233ac9c0cff1d469aec8fa127597be2"
dependencies = [
"aligned-vec",
"num-traits",
"wasm-bindgen",
]
[[package]]
name = "validator"
version = "0.20.0"
......
......@@ -32,7 +32,8 @@ allow = [
"BSL-1.0",
"MPL-2.0",
"CDLA-Permissive-2.0",
"Zlib"
"Zlib",
"NCSA"
]
# TODO exceptions
......
This diff is collapsed.
......@@ -204,12 +204,12 @@ impl
for c in prompt.chars() {
// we are returning characters not tokens, so there will be some postprocessing overhead
tokio::time::sleep(*TOKEN_ECHO_DELAY).await;
let response = deltas.create_choice(0, Some(c.to_string()), None, None);
let response = deltas.create_choice(0, Some(c.to_string()), None, None, None);
yield Annotated{ id: Some(id.to_string()), data: Some(response), event: None, comment: None };
id += 1;
}
let response = deltas.create_choice(0, None, Some(dynamo_async_openai::types::FinishReason::Stop), None);
let response = deltas.create_choice(0, None, None, Some(dynamo_async_openai::types::FinishReason::Stop), None);
yield Annotated { id: Some(id.to_string()), data: Some(response), event: None, comment: None };
};
......
......@@ -91,6 +91,7 @@ impl DeltaGenerator {
// Reasoning parser type
// This is hardcoded for now, but can be made configurable later.
// TODO: Make parser type configurable once front-end integration is determined
// Change to ReasoningParserType::GptOss to test the GPT OSS parser
let reasoning_parser_type = ReasoningParserType::Basic;
// Reasoning parser wrapper
......@@ -121,7 +122,7 @@ impl DeltaGenerator {
pub fn create_logprobs(
&self,
tokens: Vec<common::llm_backend::TokenType>,
token_ids: Vec<TokenIdType>,
token_ids: &[TokenIdType],
logprobs: Option<common::llm_backend::LogProbs>,
top_logprobs: Option<common::llm_backend::TopLogprobs>,
) -> Option<dynamo_async_openai::types::ChatChoiceLogprobs> {
......@@ -132,7 +133,7 @@ impl DeltaGenerator {
let toks = tokens
.into_iter()
.zip(token_ids)
.map(|(token, token_id)| (token.unwrap_or_default(), token_id))
.map(|(token, token_id)| (token.unwrap_or_default(), *token_id))
.collect::<Vec<(String, TokenIdType)>>();
let tok_lps = toks
.iter()
......@@ -183,11 +184,18 @@ impl DeltaGenerator {
})
}
fn create_reasoning_content(&mut self, text: Option<String>) -> Option<ParserResult> {
let text = text?;
fn create_reasoning_content(
&mut self,
text: &Option<String>,
token_ids: &[u32],
) -> Option<ParserResult> {
let text_ref = text.as_deref().unwrap_or("");
if text_ref.is_empty() && token_ids.is_empty() {
return None;
}
let parser_result = self
.reasoning_parser
.parse_reasoning_streaming_incremental(&text);
.parse_reasoning_streaming_incremental(text_ref, token_ids);
Some(parser_result)
}
......@@ -207,17 +215,12 @@ impl DeltaGenerator {
&mut self,
index: u32,
text: Option<String>,
reasoning_content: Option<String>,
finish_reason: Option<dynamo_async_openai::types::FinishReason>,
logprobs: Option<dynamo_async_openai::types::ChatChoiceLogprobs>,
) -> NvCreateChatCompletionStreamResponse {
let reasoning_parser_result = self.create_reasoning_content(text).unwrap_or_default();
let (normal_text, reasoning_content) = (
reasoning_parser_result.get_some_normal_text(),
reasoning_parser_result.get_some_reasoning(),
);
let delta = dynamo_async_openai::types::ChatCompletionStreamResponseDelta {
content: normal_text,
content: text,
function_call: None,
tool_calls: None,
role: if self.msg_counter == 0 {
......@@ -292,7 +295,7 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamRes
let logprobs = self.create_logprobs(
delta.tokens,
delta.token_ids,
&delta.token_ids,
delta.log_probs,
delta.top_logprobs,
);
......@@ -318,9 +321,24 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamRes
None => None,
};
let reasoning_parser_result = self
.create_reasoning_content(&delta.text, &delta.token_ids)
.unwrap_or_default();
let (normal_text, reasoning_content) = (
reasoning_parser_result.get_some_normal_text(),
reasoning_parser_result.get_some_reasoning(),
);
// Create the streaming response.
let index = 0;
let stream_response = self.create_choice(index, delta.text, finish_reason, logprobs);
let stream_response = self.create_choice(
index,
normal_text,
reasoning_content,
finish_reason,
logprobs,
);
Ok(stream_response)
}
......
......@@ -100,7 +100,7 @@ impl
let stream = stream! {
tokio::time::sleep(std::time::Duration::from_millis(max_tokens)).await;
for i in 0..10 {
let output = generator.create_choice(i,Some(format!("choice {i}")), None, None);
let output = generator.create_choice(i,Some(format!("choice {i}")), None, None, None);
yield Annotated::from_data(output);
}
......
......@@ -32,4 +32,6 @@ serde_json = { workspace = true }
tracing = { workspace = true }
uuid = { workspace = true }
regex = "1"
\ No newline at end of file
regex = "1"
openai-harmony = "0.0.3"
lazy_static = "1.5.0"
......@@ -33,7 +33,7 @@ impl BasicReasoningParser {
}
impl ReasoningParser for BasicReasoningParser {
fn detect_and_parse_reasoning(&self, text: &str) -> ParserResult {
fn detect_and_parse_reasoning(&mut self, text: &str, _token_ids: &[u32]) -> ParserResult {
log::debug!("detect_and_parse_reasoning called with text: {:?}", text);
let in_reasoning = self._in_reasoning || text.contains(&self.think_start_token);
......@@ -82,7 +82,11 @@ impl ReasoningParser for BasicReasoningParser {
}
}
fn parse_reasoning_streaming_incremental(&mut self, text: &str) -> ParserResult {
fn parse_reasoning_streaming_incremental(
&mut self,
text: &str,
_token_ids: &[u32],
) -> ParserResult {
// Incrementally parse the streaming text
self._buffer.push_str(text);
let mut current_text = self._buffer.to_string();
......@@ -180,26 +184,26 @@ mod tests {
#[test]
fn test_detect_and_parse_reasoning_reasoning() {
let parser =
let mut parser =
BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true);
let result =
parser.detect_and_parse_reasoning("<think>with reasoning</think> and more text.");
parser.detect_and_parse_reasoning("<think>with reasoning</think> and more text.", &[]);
assert_eq!(result.normal_text, "and more text.");
assert_eq!(result.reasoning_text, "with reasoning");
}
#[test]
fn test_detect_and_parse_reasoning_reasoning_no_reasoning() {
let parser =
let mut parser =
BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true);
let result = parser.detect_and_parse_reasoning("This is a test without reasoning.");
let result = parser.detect_and_parse_reasoning("This is a test without reasoning.", &[]);
assert_eq!(result.normal_text, "This is a test without reasoning.");
assert_eq!(result.reasoning_text, "");
}
#[test]
fn test_detect_and_parse_reasoning_reasoning_truncated_reasoning() {
let parser =
let mut parser =
BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true);
let result = parser.detect_and_parse_reasoning("<think>with truncated reasoning");
let result = parser.detect_and_parse_reasoning("<think>with truncated reasoning", &[]);
assert_eq!(result.normal_text, "");
assert_eq!(result.reasoning_text, "with truncated reasoning");
}
......@@ -208,7 +212,7 @@ mod tests {
fn test_parse_reasoning_streaming_incremental() {
let mut parser =
BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true);
let result = parser.parse_reasoning_streaming_incremental("<thi");
let result = parser.parse_reasoning_streaming_incremental("<thi", &[]);
assert_eq!(result.normal_text, "");
assert_eq!(result.reasoning_text, "");
}
......@@ -217,8 +221,10 @@ mod tests {
fn test_parse_reasoning_streaming_incremental_complete() {
let mut parser =
BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true);
let result = parser
.parse_reasoning_streaming_incremental("<think>with reasoning</think> and more text.");
let result = parser.parse_reasoning_streaming_incremental(
"<think>with reasoning</think> and more text.",
&[],
);
assert_eq!(result.normal_text, " and more text.");
assert_eq!(result.reasoning_text, "with reasoning");
}
......@@ -227,17 +233,18 @@ mod tests {
fn test_parse_reasoning_streaming_incremental_no_end_token() {
let mut parser =
BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), true, true);
let result = parser.parse_reasoning_streaming_incremental("<think>with reasoning");
let result = parser.parse_reasoning_streaming_incremental("<think>with reasoning", &[]);
assert_eq!(result.normal_text, "");
assert_eq!(result.reasoning_text, "with reasoning");
}
#[test]
fn test_detect_and_parse_reasoning_multiple_reasoning_blocks() {
let parser =
let mut parser =
BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true);
let result = parser.detect_and_parse_reasoning(
"<think>first reasoning</think> middle <think>second reasoning</think> end",
&[],
);
// The current implementation only handles the first occurrence properly
assert_eq!(result.normal_text, "middle second reasoning</think> end");
......@@ -248,14 +255,14 @@ mod tests {
fn test_streaming_multiple_reasoning_blocks() {
let mut parser =
BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, false);
let result1 =
parser.parse_reasoning_streaming_incremental("<think>first reasoning</think> middle");
let result1 = parser
.parse_reasoning_streaming_incremental("<think>first reasoning</think> middle", &[]);
assert_eq!(result1.normal_text, " middle");
assert_eq!(result1.reasoning_text, "first reasoning");
// Basic parser assumes only one reasoning block at a time
let result2 =
parser.parse_reasoning_streaming_incremental(" <think>second reasoning</think> end");
let result2 = parser
.parse_reasoning_streaming_incremental(" <think>second reasoning</think> end", &[]);
assert_eq!(result2.normal_text, " <think>second reasoning</think> end");
assert_eq!(result2.reasoning_text, "");
}
......@@ -266,13 +273,15 @@ mod tests {
BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true);
// Feed partial opening tag
let result1 = parser.parse_reasoning_streaming_incremental("<th");
let result1 = parser.parse_reasoning_streaming_incremental("<th", &[]);
assert_eq!(result1.normal_text, "");
assert_eq!(result1.reasoning_text, "");
// Complete the opening tag and add content
let result2 = parser
.parse_reasoning_streaming_incremental("ink>reasoning content</think> normal text");
let result2 = parser.parse_reasoning_streaming_incremental(
"ink>reasoning content</think> normal text",
&[],
);
assert_eq!(result2.normal_text, " normal text");
assert_eq!(result2.reasoning_text, "reasoning content");
}
......@@ -283,12 +292,13 @@ mod tests {
BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, false);
// Start with complete opening and partial content
let result1 = parser.parse_reasoning_streaming_incremental("<think>reasoning content</th");
let result1 =
parser.parse_reasoning_streaming_incremental("<think>reasoning content</th", &[]);
assert_eq!(result1.normal_text, "");
assert_eq!(result1.reasoning_text, "");
// Complete the closing tag
let result2 = parser.parse_reasoning_streaming_incremental("ink> normal text");
let result2 = parser.parse_reasoning_streaming_incremental("ink> normal text", &[]);
assert_eq!(result2.normal_text, " normal text");
assert_eq!(result2.reasoning_text, "reasoning content");
}
......@@ -299,22 +309,22 @@ mod tests {
BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, false);
// First call - partial opening tag
let result1 = parser.parse_reasoning_streaming_incremental("<th");
let result1 = parser.parse_reasoning_streaming_incremental("<th", &[]);
assert_eq!(result1.normal_text, "");
assert_eq!(result1.reasoning_text, "");
// Second call - complete opening tag, start reasoning
let result2 = parser.parse_reasoning_streaming_incremental("ink>part1 ");
let result2 = parser.parse_reasoning_streaming_incremental("ink>part1 ", &[]);
assert_eq!(result2.normal_text, "");
assert_eq!(result2.reasoning_text, "");
// Third call - more reasoning content
let result3 = parser.parse_reasoning_streaming_incremental("part2 ");
let result3 = parser.parse_reasoning_streaming_incremental("part2 ", &[]);
assert_eq!(result3.normal_text, "");
assert_eq!(result3.reasoning_text, "");
// Fourth call - end reasoning and normal text
let result4 = parser.parse_reasoning_streaming_incremental("part3</think> normal");
let result4 = parser.parse_reasoning_streaming_incremental("part3</think> normal", &[]);
assert_eq!(result4.normal_text, " normal");
assert_eq!(result4.reasoning_text, "part1 part2 part3");
}
......@@ -325,27 +335,28 @@ mod tests {
BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true);
// Start reasoning block
let result1 = parser.parse_reasoning_streaming_incremental("<think>reasoning ");
let result1 = parser.parse_reasoning_streaming_incremental("<think>reasoning ", &[]);
assert_eq!(result1.normal_text, "");
assert_eq!(result1.reasoning_text, "reasoning ");
// Continue streaming reasoning
let result2 = parser.parse_reasoning_streaming_incremental("content ");
let result2 = parser.parse_reasoning_streaming_incremental("content ", &[]);
assert_eq!(result2.normal_text, "");
assert_eq!(result2.reasoning_text, "content ");
// End reasoning block
let result3 = parser.parse_reasoning_streaming_incremental("more</think> normal");
let result3 = parser.parse_reasoning_streaming_incremental("more</think> normal", &[]);
assert_eq!(result3.normal_text, " normal");
assert_eq!(result3.reasoning_text, "more");
}
#[test]
fn test_nested_reasoning_blocks() {
let parser =
let mut parser =
BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true);
let result = parser.detect_and_parse_reasoning(
"<think>outer <think>inner</think> reasoning</think> normal",
&[],
);
// Current implementation should handle this by finding the first closing tag
assert_eq!(result.normal_text, "reasoning</think> normal");
......@@ -355,28 +366,28 @@ mod tests {
#[test]
fn test_malformed_missing_closing_tag() {
let parser =
let mut parser =
BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true);
let result = parser.detect_and_parse_reasoning("<think>reasoning without closing tag");
let result = parser.detect_and_parse_reasoning("<think>reasoning without closing tag", &[]);
assert_eq!(result.normal_text, "");
assert_eq!(result.reasoning_text, "reasoning without closing tag");
}
#[test]
fn test_malformed_stray_closing_tag() {
let parser =
let mut parser =
BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true);
let result = parser.detect_and_parse_reasoning("normal text</think> more normal");
let result = parser.detect_and_parse_reasoning("normal text</think> more normal", &[]);
assert_eq!(result.normal_text, "normal text</think> more normal");
assert_eq!(result.reasoning_text, "");
}
#[test]
fn test_malformed_multiple_opening_tags() {
let parser =
let mut parser =
BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true);
let result = parser
.detect_and_parse_reasoning("<think>first <think>second reasoning</think> normal");
.detect_and_parse_reasoning("<think>first <think>second reasoning</think> normal", &[]);
// Should handle by replacing all opening tags and using first closing tag
assert_eq!(result.normal_text, "normal");
assert_eq!(result.reasoning_text, "first second reasoning");
......@@ -384,27 +395,27 @@ mod tests {
#[test]
fn test_empty_reasoning_block() {
let parser =
let mut parser =
BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true);
let result = parser.detect_and_parse_reasoning("<think></think> normal text");
let result = parser.detect_and_parse_reasoning("<think></think> normal text", &[]);
assert_eq!(result.normal_text, "normal text");
assert_eq!(result.reasoning_text, "");
}
#[test]
fn test_whitespace_only_reasoning_block() {
let parser =
let mut parser =
BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true);
let result = parser.detect_and_parse_reasoning("<think> \n\t </think> normal text");
let result = parser.detect_and_parse_reasoning("<think> \n\t </think> normal text", &[]);
assert_eq!(result.normal_text, "normal text");
assert_eq!(result.reasoning_text, ""); // Should be empty after trim
}
#[test]
fn test_force_reasoning_mode() {
let parser =
let mut parser =
BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), true, true);
let result = parser.detect_and_parse_reasoning("no think tags here");
let result = parser.detect_and_parse_reasoning("no think tags here", &[]);
assert_eq!(result.normal_text, "");
assert_eq!(result.reasoning_text, "no think tags here");
}
......@@ -416,19 +427,19 @@ mod tests {
// Process complete reasoning block
let result1 =
parser.parse_reasoning_streaming_incremental("<think>reasoning</think> normal");
parser.parse_reasoning_streaming_incremental("<think>reasoning</think> normal", &[]);
assert_eq!(result1.normal_text, " normal");
assert_eq!(result1.reasoning_text, "reasoning");
// Process normal text - should not be affected by previous state
let result2 = parser.parse_reasoning_streaming_incremental(" more normal text");
let result2 = parser.parse_reasoning_streaming_incremental(" more normal text", &[]);
assert_eq!(result2.normal_text, " more normal text");
assert_eq!(result2.reasoning_text, "");
// Basic parser does not expect more than one reasoning block at a time
// So this should not affect the state
let result3 =
parser.parse_reasoning_streaming_incremental(" <think>new reasoning</think> final");
let result3 = parser
.parse_reasoning_streaming_incremental(" <think>new reasoning</think> final", &[]);
assert_eq!(result3.normal_text, " <think>new reasoning</think> final");
assert_eq!(result3.reasoning_text, "");
}
......
......@@ -24,11 +24,16 @@ impl DeepseekR1ReasoningParser {
}
impl ReasoningParser for DeepseekR1ReasoningParser {
fn parse_reasoning_streaming_incremental(&mut self, text: &str) -> ParserResult {
self.base.parse_reasoning_streaming_incremental(text)
fn parse_reasoning_streaming_incremental(
&mut self,
text: &str,
token_ids: &[u32],
) -> ParserResult {
self.base
.parse_reasoning_streaming_incremental(text, token_ids)
}
fn detect_and_parse_reasoning(&self, text: &str) -> ParserResult {
self.base.detect_and_parse_reasoning(text)
fn detect_and_parse_reasoning(&mut self, text: &str, token_ids: &[u32]) -> ParserResult {
self.base.detect_and_parse_reasoning(text, token_ids)
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::fmt::Debug;
use crate::ParserResult;
use crate::ReasoningParser;
use openai_harmony::StreamableParser;
use openai_harmony::chat::TextContent;
use openai_harmony::{HarmonyEncoding, HarmonyEncodingName, chat::Role, load_harmony_encoding};
// Static initialization of the Harmony encoding, so that creating a parser does not pay the load cost every time.
// load_harmony_encoding downloads some tiktoken files into a directory, and we don't want to repeat that for every parser we create.
use std::sync::OnceLock;
/// Process-wide cell holding the outcome (success or error) of loading the
/// Harmony GPT OSS encoding.
static GLOBAL_HARMONY_GPTOSS_ENCODING: OnceLock<Result<HarmonyEncoding, anyhow::Error>> =
    OnceLock::new();

/// Returns the shared result of loading the Harmony GPT OSS encoding.
///
/// The first call runs `load_harmony_encoding`; every later call hands back a
/// reference to the stored result. Because the whole `Result` is stored, a
/// failed load is kept as well and is not retried.
fn get_harmony_encoding() -> &'static Result<HarmonyEncoding, anyhow::Error> {
    let load_once = || load_harmony_encoding(HarmonyEncodingName::HarmonyGptOss);
    GLOBAL_HARMONY_GPTOSS_ENCODING.get_or_init(load_once)
}
/// Reasoning parser for GPT OSS output, backed by a Harmony `StreamableParser`.
pub struct GptOssReasoningParser {
// Harmony streaming parser that holds the parsing state for this instance.
parser: StreamableParser,
}
/// Hand-written `Debug` impl: `StreamableParser` does not implement `Debug`,
/// so the parser field is rendered through its `state_json()` output instead.
impl Debug for GptOssReasoningParser {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let parser_state = self.parser.state_json();
        let mut builder = f.debug_struct("GptOssReasoningParser");
        builder.field("parser", &parser_state);
        builder.finish()
    }
}
impl GptOssReasoningParser {
    /// Builds a parser on top of the shared Harmony GPT-OSS encoding, with the
    /// Harmony parser initialised for the assistant role.
    ///
    /// # Errors
    ///
    /// Returns an error (after logging a warning) when the shared encoding could
    /// not be loaded, or when the Harmony `StreamableParser` cannot be constructed
    /// from it.
    pub fn new() -> anyhow::Result<Self> {
        // The encoding is cached globally; a cached failure surfaces here.
        let encoding = get_harmony_encoding().as_ref().map_err(|e| {
            tracing::warn!("Failed to load Harmony encoding for GPT OSS: {e}");
            anyhow::anyhow!("Failed to load Harmony encoding: {e}")
        })?;

        let parser =
            StreamableParser::new(encoding.clone(), Some(Role::Assistant)).map_err(|e| {
                tracing::warn!("Harmony StreamableParser init failed for GPT OSS: {e}");
                anyhow::anyhow!("Failed to load Harmony StreamableParser: {e}")
            })?;

        Ok(Self { parser })
    }
}
impl ReasoningParser for GptOssReasoningParser {
/// Parses a complete token sequence and splits it into reasoning and normal text.
///
/// Every id in `token_ids` is fed to the Harmony parser; `_text` is ignored because
/// this parser works on token ids only. Text is then classified by Harmony channel,
/// the same rule the streaming method uses: content in the `final` channel is
/// normal text, content in any other channel (or with no channel) is reasoning
/// text. The message that is still being generated (no end token seen yet) is
/// always included, classified by the parser's current channel.
///
/// Only the first content item of each completed message is used, and only when
/// it is plain text.
///
/// If the Harmony parser rejects a token, a warning is logged and an empty
/// `ParserResult` is returned.
///
/// Note: tokens are fed into `self.parser`, which is not reset here, so messages
/// parsed by earlier calls on the same instance are part of the result too.
fn detect_and_parse_reasoning(&mut self, _text: &str, token_ids: &[u32]) -> ParserResult {
    tracing::debug!(
        "detect_and_parse_reasoning called with {} token_ids",
        token_ids.len()
    );

    let parser = &mut self.parser;
    for (i, token_id) in token_ids.iter().enumerate() {
        tracing::debug!(
            "Processing token {} of {}: {}",
            i + 1,
            token_ids.len(),
            token_id
        );
        if let Err(e) = parser.process(*token_id) {
            tracing::warn!("Harmony parse error for token_id {token_id}: {e}");
            return ParserResult::default();
        }
    }

    let mut normal_text = String::new();
    let mut reasoning_text = String::new();

    // Completed messages: route each one by its own channel instead of by its
    // position in the list, so the split does not depend on how many messages
    // happen to have been closed.
    let output_msgs = parser.messages();
    tracing::debug!("Parser has {} output messages", output_msgs.len());
    for msg in output_msgs {
        if let Some(openai_harmony::chat::Content::Text(TextContent { text })) =
            msg.content.first()
        {
            if msg.channel.as_deref() == Some("final") {
                normal_text.push_str(text);
            } else {
                reasoning_text.push_str(text);
            }
        }
    }

    // In-progress message: always include it, otherwise trailing text that has
    // not been terminated by an end token would be lost.
    let current = parser.current_content().unwrap_or_default();
    tracing::debug!("Current content length: {}", current.len());
    if parser.current_channel().as_deref() == Some("final") {
        normal_text.push_str(&current);
    } else {
        reasoning_text.push_str(&current);
    }

    tracing::debug!(
        "Final result - normal_text: {} chars, reasoning_text: {} chars",
        normal_text.len(),
        reasoning_text.len()
    );
    ParserResult {
        normal_text,
        reasoning_text,
    }
}
/// Feeds one streaming chunk of token ids to the Harmony parser and returns the
/// text newly produced by this chunk.
///
/// `_text` is ignored; this parser works on token ids only. The content delta is
/// collected after *every* token, not just once per chunk, so a chunk that
/// carries several tokens yields all of its text, and a chunk that crosses a
/// channel boundary has each piece attributed to the channel it was produced in.
/// Deltas produced in the `final` channel go to `normal_text`; deltas from any
/// other channel go to `reasoning_text`. Tokens processed while the parser
/// reports no current channel contribute no text.
///
/// If the Harmony parser rejects a token, a warning is logged and the text
/// gathered from the tokens before it is returned (empty when the first token of
/// the chunk fails).
fn parse_reasoning_streaming_incremental(
    &mut self,
    _text: &str,
    token_ids: &[u32],
) -> ParserResult {
    tracing::debug!(
        "parse_reasoning_streaming_incremental called with {} token_ids",
        token_ids.len()
    );

    let parser: &mut StreamableParser = &mut self.parser;
    let mut result = ParserResult::default();

    for (i, token_id) in token_ids.iter().enumerate() {
        tracing::debug!(
            "Processing streaming token {} of {}: {}",
            i + 1,
            token_ids.len(),
            token_id
        );
        if let Err(e) = parser.process(*token_id) {
            tracing::warn!("Harmony parse error for token_id {token_id}: {e}");
            return result;
        }

        let Some(channel) = parser.current_channel() else {
            tracing::debug!("No current channel detected");
            continue;
        };

        // `last_content_delta` only reflects the token that was just processed,
        // so it has to be read inside the loop.
        let Some(delta) = parser.last_content_delta().unwrap_or_default() else {
            tracing::debug!("No content delta in channel: {}", channel);
            continue;
        };

        if channel == "final" {
            tracing::debug!("Got normal text delta of {} chars", delta.len());
            result.normal_text.push_str(&delta);
        } else {
            tracing::debug!("Got reasoning text delta of {} chars", delta.len());
            result.reasoning_text.push_str(&delta);
        }
    }

    result
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Harmony-formatted completion: a closed `analysis` message followed by a
    /// `final` message that has no end token yet.
    const SAMPLE_COMPLETION: &str = "<|channel|>analysis<|message|>The user asks a simple factual question: capital of Brazil. The answer is Brasília. No additional explanation needed.<|end|><|start|>assistant<|channel|>final<|message|>The capital of Brazil is Brasília.";

    /// Text expected to be reported as reasoning for `SAMPLE_COMPLETION`.
    const EXPECTED_REASONING: &str = "The user asks a simple factual question: capital of Brazil. The answer is Brasília. No additional explanation needed.";

    /// Text expected to be reported as normal output for `SAMPLE_COMPLETION`.
    const EXPECTED_NORMAL: &str = "The capital of Brazil is Brasília.";

    /// Tokenizes `SAMPLE_COMPLETION` with the shared Harmony encoding, keeping the
    /// special tokens so the parser sees the channel markers.
    fn sample_token_ids() -> Vec<u32> {
        let enc = get_harmony_encoding()
            .as_ref()
            .expect("Failed to get encoding");
        enc.tokenizer().encode_with_special_tokens(SAMPLE_COMPLETION)
    }

    #[test]
    fn test_gpt_oss_reasoning_parser() {
        let mut parser = GptOssReasoningParser::new().expect("Failed to create parser");
        let token_ids = sample_token_ids();

        let result = parser.detect_and_parse_reasoning("Test text", &token_ids);

        assert_eq!(result.normal_text, EXPECTED_NORMAL);
        assert_eq!(result.reasoning_text, EXPECTED_REASONING);
    }

    #[test]
    fn test_gpt_oss_reasoning_parser_streaming() {
        let mut parser = GptOssReasoningParser::new().expect("Failed to create parser");
        let token_ids = sample_token_ids();

        // Feed one token per chunk and stitch the returned deltas back together.
        let mut reasoning_text_incr = String::new();
        let mut normal_text_incr = String::new();
        for token in token_ids.iter() {
            let result = parser.parse_reasoning_streaming_incremental("Test text", &[*token]);
            normal_text_incr.push_str(&result.normal_text);
            reasoning_text_incr.push_str(&result.reasoning_text);
        }

        assert_eq!(normal_text_incr, EXPECTED_NORMAL);
        assert_eq!(reasoning_text_incr, EXPECTED_REASONING);
    }
}
......@@ -3,10 +3,12 @@
mod base_parser;
mod deepseek_r1_parser;
mod gpt_oss_parser;
// Re-export main types and functions for convenience
pub use base_parser::BasicReasoningParser;
pub use deepseek_r1_parser::DeepseekR1ReasoningParser;
pub use gpt_oss_parser::GptOssReasoningParser;
#[derive(Debug, Clone, Default)]
pub struct ParserResult {
......@@ -39,12 +41,16 @@ pub trait ReasoningParser: Send + std::fmt::Debug {
/// Parses a standalone, non-streaming input chunk. Implementations may reset or ignore
/// internal streaming state and should return the split of normal vs reasoning text for
/// this complete input. Marker tokens must not be included in either output.
fn detect_and_parse_reasoning(&self, text: &str) -> ParserResult;
fn detect_and_parse_reasoning(&mut self, text: &str, token_ids: &[u32]) -> ParserResult;
/// Parses a streaming chunk and updates internal state. The return value should be the
/// delta: only the newly discovered normal and reasoning text attributable to this chunk
/// (not the cumulative totals). Marker tokens must not be included in either output.
fn parse_reasoning_streaming_incremental(&mut self, text: &str) -> ParserResult;
fn parse_reasoning_streaming_incremental(
&mut self,
text: &str,
token_ids: &[u32],
) -> ParserResult;
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
......@@ -52,6 +58,7 @@ pub trait ReasoningParser: Send + std::fmt::Debug {
/// Selects which reasoning parser implementation to construct.
pub enum ReasoningParserType {
/// NOTE(review): presumably selects `DeepseekR1ReasoningParser`; the construction
/// arm is not fully visible here, so confirm against the factory method.
DeepseekR1,
/// NOTE(review): presumably selects a `BasicReasoningParser`; confirm the exact
/// markers and flags against the factory method.
Basic,
/// Selects `GptOssReasoningParser`. If that parser cannot be initialised, the
/// factory logs a warning and falls back to a `BasicReasoningParser` using the
/// `<think>` / `</think>` markers.
GptOss,
}
#[derive(std::fmt::Debug)]
......@@ -60,12 +67,17 @@ pub struct ReasoningParserWrapper {
}
impl ReasoningParser for ReasoningParserWrapper {
fn detect_and_parse_reasoning(&self, text: &str) -> ParserResult {
self.parser.detect_and_parse_reasoning(text)
fn detect_and_parse_reasoning(&mut self, text: &str, token_ids: &[u32]) -> ParserResult {
self.parser.detect_and_parse_reasoning(text, token_ids)
}
fn parse_reasoning_streaming_incremental(&mut self, text: &str) -> ParserResult {
self.parser.parse_reasoning_streaming_incremental(text)
fn parse_reasoning_streaming_incremental(
&mut self,
text: &str,
token_ids: &[u32],
) -> ParserResult {
self.parser
.parse_reasoning_streaming_incremental(text, token_ids)
}
}
......@@ -83,6 +95,24 @@ impl ReasoningParserType {
true,
)),
},
ReasoningParserType::GptOss => match GptOssReasoningParser::new() {
Ok(parser) => ReasoningParserWrapper {
parser: Box::new(parser),
},
Err(e) => {
tracing::warn!(
"GptOssReasoningParser could not be initialized, falling back to Basic Reasoning Parser: {e}"
);
ReasoningParserWrapper {
parser: Box::new(BasicReasoningParser::new(
"<think>".into(),
"</think>".into(),
false,
true,
)),
}
}
},
}
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment