"docs/vscode:/vscode.git/clone" did not exist on "b24f78349c60ebf2e64fa7f9ae396e1f87c22033"
Unverified Commit a0a77d93 authored by Jonas, committed by GitHub

Fix Harmony reasoning parser and auto-separation for gpt-oss models (#9190)


Co-authored-by: Chang Su <chang.s.su@oracle.com>
Co-authored-by: Chayenne <zhaochen20@outlook.com>
Co-authored-by: zhaochenyang20 <zhaochenyang20@gmail.com>
Co-authored-by: minleminzui <2969413251@qq.com>
Co-authored-by: maocheng23 <maocheng@berkeley.edu>
Co-authored-by: Xinyuan Tong <xinyuantong.cs@gmail.com>
parent 24a8cee6
@@ -148,6 +148,16 @@ class OpenAIServingChat(OpenAIServingBase):
         self, request: ChatCompletionRequest, is_multimodal: bool
     ) -> MessageProcessingResult:
         """Process chat messages and apply chat template"""
+        is_gpt_oss = (
+            hasattr(self.tokenizer_manager.model_config, "hf_config")
+            and hasattr(self.tokenizer_manager.model_config.hf_config, "model_type")
+            and self.tokenizer_manager.model_config.hf_config.model_type == "gpt_oss"
+        )
+
+        # GptOss model needs to keep special tokens for harmony parsing
+        if is_gpt_oss:
+            request.skip_special_tokens = False
+
         tool_call_constraint = None
         # Apply chat template and its stop strings
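Why this matters: the Harmony reasoning parser downstream keys off special-token markers such as <|channel|>analysis<|message|> in the decoded string, so those tokens must survive detokenization. A minimal sketch of the failure mode (the token table and decode() below are made up for illustration, not SGLang's actual tokenizer API):

# Toy illustration: ids and decode() are invented for this sketch.
SPECIAL = {200005: "<|channel|>", 200008: "<|message|>", 200007: "<|end|>"}
VOCAB = {1: "analysis", 2: "2 + 2 = 4.", 3: "final", 4: "It is 4."}

def decode(ids, skip_special_tokens):
    pieces = []
    for i in ids:
        if i in SPECIAL:
            if not skip_special_tokens:
                pieces.append(SPECIAL[i])
        else:
            pieces.append(VOCAB[i])
    return "".join(pieces)

ids = [200005, 1, 200008, 2, 200007, 200005, 3, 200008, 4]
print(decode(ids, skip_special_tokens=True))   # "analysis2 + 2 = 4.finalIt is 4." -- channel structure lost
print(decode(ids, skip_special_tokens=False))  # "<|channel|>analysis<|message|>2 + 2 = 4.<|end|>..." -- parseable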
This diff is collapsed.
@@ -106,6 +106,8 @@ class DetokenizerManager:
             ]
         )
 
+        self.is_tool_call_parser_gpt_oss = server_args.tool_call_parser == "gpt-oss"
+
     def event_loop(self):
         """The event loop that handles requests"""
         while True:
@@ -133,6 +135,9 @@ class DetokenizerManager:
         # Trim stop token.
         if isinstance(matched, int) and isinstance(output, list):
+            # 200012 <|call|> is the tool call token and one of eos tokens for gpt-oss model
+            if output[-1] == 200012 and self.is_tool_call_parser_gpt_oss:
+                return output
             assert len(output) > 0
             return output[:-1]
         return output
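The trimming rule above, restated as a small standalone function (illustrative only; in the real code this logic lives inline in DetokenizerManager): a trailing <|call|> (token id 200012) is both an EOS token and the tool-call marker for gpt-oss, so it must be passed through to the function-call detector rather than trimmed like an ordinary stop token.

CALL_TOKEN_ID = 200012  # <|call|>: tool-call token and one of the eos tokens for gpt-oss

def trim_stop_token(output, is_tool_call_parser_gpt_oss):
    """Drop the matched stop token id, except a trailing <|call|> under gpt-oss."""
    if is_tool_call_parser_gpt_oss and output and output[-1] == CALL_TOKEN_ID:
        return output  # keep <|call|> so the tool-call parser can detect the call
    assert len(output) > 0
    return output[:-1]

assert trim_stop_token([11, 22, 200012], True) == [11, 22, 200012]
assert trim_stop_token([11, 22, 200012], False) == [11, 22]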
 import re
 from typing import Dict, Optional, Tuple, Type
 
+from sglang.srt.harmony_parser import HarmonyParser
+
 
 class StreamingParseResult:
     """Result of streaming incremental parsing."""
 
-    def __init__(self, normal_text: str = "", reasoning_text: str = ""):
-        self.normal_text = normal_text
-        self.reasoning_text = reasoning_text
+    def __init__(
+        self,
+        normal_text: Optional[str] = None,
+        reasoning_text: Optional[str] = None,
+    ):
+        self.normal_text = normal_text or ""
+        self.reasoning_text = reasoning_text or ""
 
 
 class BaseReasoningFormatDetector:
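The widened constructor makes StreamingParseResult None-safe: detectors may now hand back Optional[str] values and the result still exposes plain str fields. A self-contained restatement of the new behavior:

from typing import Optional

class StreamingParseResult:  # mirrors the definition in the diff above
    def __init__(
        self,
        normal_text: Optional[str] = None,
        reasoning_text: Optional[str] = None,
    ):
        self.normal_text = normal_text or ""
        self.reasoning_text = reasoning_text or ""

res = StreamingParseResult(normal_text=None, reasoning_text="thinking...")
assert res.normal_text == "" and res.reasoning_text == "thinking..."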
@@ -188,316 +194,60 @@ class KimiDetector(BaseReasoningFormatDetector):
 class GptOssDetector(BaseReasoningFormatDetector):
     """
-    Detector for T4-style reasoning format.
-
-    Assumes reasoning format with two channels:
-    <|channel|>analysis<|message|>...reasoning content...<|end|>
-    <|start|>assistant<|channel|>final<|message|>...final answer...<|return|>
-
-    Returns content from 'analysis' channel as reasoning_text
-    and content from 'final' channel as normal_text.
-
-    Args:
-        stream_reasoning (bool): If False, accumulates reasoning content until complete.
-            If True, streams reasoning content as it arrives.
+    Detector for T4-style reasoning format (GPT-OSS), using the HarmonyParser.
     """
 
     def __init__(self, stream_reasoning: bool = True, force_reasoning: bool = True):
-        # TypeScript uses channel tokens instead of simple start/end tokens
         super().__init__(
             "<|channel|>analysis<|message|>",
             "<|end|>",
-            force_reasoning=True,
+            force_reasoning=force_reasoning,
             stream_reasoning=stream_reasoning,
         )
-        self.final_channel_start = "<|start|>assistant<|channel|>final<|message|>"
-        self.final_channel_end = "<|return|>"
-        self._in_final_channel = False
-        self._analysis_complete = False
-        self._in_reasoning = True
+        self.parser = HarmonyParser()
 
     def detect_and_parse(self, text: str) -> StreamingParseResult:
-        """
-        One-time parsing: Detects and parses both analysis and final channels.
-        Tool call channels are preserved in normal_text for downstream processing.
-
-        HACK: Also handles simplified format where text starts with "analysis" and transitions
-        to "assistantfinal" without full channel markers.
-        """
-        # HACK: Handle simplified format (analysis...assistantfinal) without channel markers
-        if (
-            text.startswith("analysis")
-            and "assistantfinal" in text
-            and "<|channel|>" not in text
-        ):
-            # Split on "assistantfinal"
-            parts = text.split("assistantfinal", 1)
-            self._in_reasoning = False
-            if len(parts) == 2:
-                reasoning_text = parts[0][
-                    len("analysis") :
-                ].strip()  # Remove "analysis" prefix
-                normal_text = parts[1].strip()
-                return StreamingParseResult(
-                    normal_text=normal_text, reasoning_text=reasoning_text
-                )
-
-        reasoning_parts = []
-        normal_parts = []
-        current_pos = 0
-
-        # Process text sequentially to preserve tool calls between analysis sections
-        while current_pos < len(text):
-            # Look for next analysis channel
-            analysis_start_idx = text.find(self.think_start_token, current_pos)
-            if analysis_start_idx == -1:
-                # No more analysis channels, rest goes to remaining
-                break
-
-            # Preserve any content before this analysis channel (could include tool calls)
-            if analysis_start_idx > current_pos:
-                between_content = text[current_pos:analysis_start_idx]
-                # This content will be added to normal_parts later
-                normal_parts.append(between_content)
-
-            # Extract analysis content
-            analysis_content_start = analysis_start_idx + len(self.think_start_token)
-            analysis_end_idx = text.find(self.think_end_token, analysis_content_start)
-
-            if analysis_end_idx != -1:
-                reasoning_parts.append(
-                    text[analysis_content_start:analysis_end_idx].strip()
-                )
-                current_pos = analysis_end_idx + len(self.think_end_token)
-            else:
-                # Analysis not complete
-                reasoning_parts.append(text[analysis_content_start:].strip())
-                reasoning_text = "".join(reasoning_parts)
-                return StreamingParseResult(reasoning_text=reasoning_text)
-
-        # Add any remaining text after all analysis sections
-        if current_pos < len(text):
-            remaining = text[current_pos:]
-            normal_parts.append(remaining)
-
-        # Process non-analysis content for commentary sections
-        full_normal_text = "".join(normal_parts)
-
-        # Extract reasoning from non-tool-call commentary sections
-        # Tool calls have "to=" in their header, regular commentary does not
-        commentary_pattern = re.compile(
-            r"<\|start\|>assistant<\|channel\|>commentary<\|message\|>(.*?)(?:<\|end\|>|<\|call\|>)",
-            re.DOTALL,
-        )
-
-        cleaned_text = full_normal_text
-        for match in reversed(list(commentary_pattern.finditer(full_normal_text))):
-            # Check if this commentary is a tool call by looking at the text before <|message|>
-            match_start = match.start()
-            # Find where "<|channel|>commentary" starts within the matched pattern
-            # The pattern starts with "<|start|>assistant<|channel|>commentary"
-            # So we look for the text between "commentary" and "<|message|>" in the match
-            match_text = full_normal_text[match_start : match.end()]
-            commentary_idx = match_text.find("<|channel|>commentary")
-            if commentary_idx != -1:
-                message_idx = match_text.find("<|message|>", commentary_idx)
-                if message_idx != -1:
-                    between_text = match_text[commentary_idx:message_idx]
-                    # If no "to=" found, this is regular commentary (reasoning content)
-                    if " to=" not in between_text:
-                        content = match.group(1).strip()
-                        reasoning_parts.append(content)
-                        # Remove this commentary section from normal text
-                        cleaned_text = (
-                            cleaned_text[: match.start()] + cleaned_text[match.end() :]
-                        )
-
-        full_normal_text = cleaned_text
-
-        # Combine all reasoning parts
-        reasoning_text = "".join(reasoning_parts)
-
-        # Process full_normal_text for final output
-        normal_text = ""
-        if self.final_channel_start in full_normal_text:
-            final_start = full_normal_text.find(self.final_channel_start)
-            final_content_start = final_start + len(self.final_channel_start)
-            final_end = full_normal_text.find(
-                self.final_channel_end, final_content_start
-            )
-
-            if final_end != -1:
-                # Extract content before final channel (includes tool calls)
-                before_final = full_normal_text[:final_start].strip()
-                # Extract ONLY the final channel content (not the channel markers)
-                final_text = full_normal_text[final_content_start:final_end].strip()
-                # Extract content after final channel
-                after_final = full_normal_text[
-                    final_end + len(self.final_channel_end) :
-                ].strip()
-
-                # For tool calls + final answer: concatenate tool calls with final text
-                parts = []
-                if before_final:
-                    parts.append(before_final)
-                if final_text:
-                    parts.append(final_text)
-                if after_final:
-                    parts.append(after_final)
-                normal_text = " ".join(parts)
-            else:
-                # Final channel not complete - extract what we have
-                # Look for just <|channel|>final<|message|> without <|return|>
-                alt_final_start = full_normal_text.find("<|channel|>final<|message|>")
-                if alt_final_start != -1:
-                    before_alt_final = full_normal_text[:alt_final_start].strip()
-                    alt_final_content = full_normal_text[
-                        alt_final_start + len("<|channel|>final<|message|>") :
-                    ].strip()
-                    parts = []
-                    if before_alt_final:
-                        parts.append(before_alt_final)
-                    if alt_final_content:
-                        parts.append(alt_final_content)
-                    normal_text = " ".join(parts)
-                else:
-                    normal_text = full_normal_text.strip()
-        else:
-            # No final channel, treat all as normal text (includes tool calls)
-            normal_text = full_normal_text.strip()
+        events = self.parser.parse(text)
+        # Flush the buffer for one-shot parsing
+        events += self.parser.parse("")
+
+        reasoning_text = "".join(
+            [e.content for e in events if e.event_type == "reasoning"]
+        )
+
+        normal_parts = []
+        for e in events:
+            if e.event_type == "normal":
+                normal_parts.append(e.content)
+            elif e.event_type == "tool_call":
+                # Use raw_text to preserve structural markers for function call detector
+                normal_parts.append(e.raw_text if e.raw_text else e.content)
+        normal_text = "".join(normal_parts)
+        # Tool call events preserve raw text with structural markers
 
         return StreamingParseResult(
-            normal_text=normal_text, reasoning_text=reasoning_text
+            normal_text=normal_text,
+            reasoning_text=reasoning_text,
         )
 
     def parse_streaming_increment(self, new_text: str) -> StreamingParseResult:
-        """
-        Streaming incremental parsing for GPT-OSS format.
-
-        This is a simplified streaming implementation that accumulates content
-        and delegates to the non-streaming parser for complex multi-channel parsing.
-
-        TODO: Implement proper incremental parsing for better streaming performance.
-        """
-        self._buffer += new_text
-
-        if not self._in_reasoning:
-            return StreamingParseResult(normal_text=new_text)
-
-        # Check if we have complete sections to process
-        # For GPT-OSS, we need to wait for complete channel sections
-        # HACK: For now, use simplified approach - wait for key markers before processing
-        key_markers = ["<|end|>", "<|call|>", "<|return|>", "assistantfinal"]
-        has_complete_section = any(marker in self._buffer for marker in key_markers)
-
-        if not has_complete_section:
-            # Still accumulating, don't process yet
-            return StreamingParseResult()
-
-        # Handle simplified format (analysis...assistantfinal) with true incremental streaming
-        if (
-            "<|channel|>" not in self._buffer
-        ):  # Simplified format without channel markers
-            if self._buffer.startswith("analysis"):
-                # Check if we have the transition to assistantfinal
-                if "assistantfinal" in self._buffer:
-                    self._in_reasoning = False
-                    # Complete reasoning section - extract and stream it
-                    parts = self._buffer.split("assistantfinal", 1)
-                    reasoning_text = parts[0][len("analysis") :].strip()
-                    final_content = parts[1].strip()
-
-                    # Clear buffer and return both reasoning and final content
-                    self._buffer = ""
-                    return StreamingParseResult(
-                        reasoning_text=reasoning_text if self.stream_reasoning else "",
-                        normal_text=final_content,
-                    )
-                elif self.stream_reasoning:
-                    # Stream reasoning content incrementally as it arrives
-                    current_reasoning = self._buffer[len("analysis") :].strip()
-                    self._buffer = ""
-                    return StreamingParseResult(reasoning_text=current_reasoning)
-                else:
-                    # Wait for assistantfinal
-                    return StreamingParseResult()
-            elif self._buffer.startswith("assistantfinal"):
-                # Direct final content without analysis
-                final_content = self._buffer[len("assistantfinal") :].strip()
-                self._buffer = ""
-                return StreamingParseResult(normal_text=final_content)
-
-        # For full channel format, process sections as they complete
-        result = StreamingParseResult()
-
-        # Process complete analysis sections
-        while (
-            self.think_start_token in self._buffer
-            and self.think_end_token in self._buffer
-        ):
-            start_idx = self._buffer.find(self.think_start_token)
-            start_pos = start_idx + len(self.think_start_token)
-            end_pos = self._buffer.find(self.think_end_token, start_pos)
-
-            if end_pos != -1:
-                reasoning_content = self._buffer[start_pos:end_pos].strip()
-                if self.stream_reasoning and reasoning_content:
-                    result.reasoning_text += reasoning_content
-                # Remove processed analysis section
-                self._buffer = (
-                    self._buffer[:start_idx]
-                    + self._buffer[end_pos + len(self.think_end_token) :]
-                )
-            else:
-                break
-
-        # Process complete commentary sections
-        commentary_pattern = re.compile(
-            r"<\|start\|>assistant<\|channel\|>commentary<\|message\|>(.*?)(?:<\|end\|>|<\|call\|>)",
-            re.DOTALL,
-        )
-
-        for match in reversed(list(commentary_pattern.finditer(self._buffer))):
-            # Check if this is a tool call
-            start_pos = match.start()
-            commentary_content = match.group(1).strip()
-            if self.stream_reasoning and commentary_content:
-                result.reasoning_text += commentary_content
-            # Remove this commentary section
-            self._buffer = self._buffer[: match.start()] + self._buffer[match.end() :]
-
-        # Clean up any standalone <|start|>assistant
-        self._buffer = re.sub(
-            r"<\|start\|>assistant(?=<\|start\|>assistant)", "", self._buffer
-        )
-
-        # Handle final channel completion
-        if self.final_channel_start in self._buffer:
-            final_start = self._buffer.find(self.final_channel_start)
-            final_content_start = final_start + len(self.final_channel_start)
-
-            # Check if final channel is complete
-            final_end = self._buffer.find(self.final_channel_end, final_content_start)
-
-            if final_end != -1:
-                # Complete final channel - process everything
-                final_result = self.detect_and_parse(self._buffer)
-                self._buffer = ""
-                return StreamingParseResult(
-                    normal_text=final_result.normal_text,
-                    reasoning_text=result.reasoning_text + final_result.reasoning_text,
-                )
-            else:
-                # Extract content before final channel (e.g. tool calls)
-                before_final = self._buffer[:final_start]
-                if before_final:
-                    # Output tool calls for processing
-                    result.normal_text += before_final
-                    # Keep the final channel part in buffer
-                    self._buffer = self._buffer[final_start:]
-
-        return result
+        events = self.parser.parse(new_text)
+
+        reasoning_text = "".join(
+            [e.content for e in events if e.event_type == "reasoning"]
+        )
+
+        normal_parts = []
+        for e in events:
+            if e.event_type == "normal":
+                normal_parts.append(e.content)
+            elif e.event_type == "tool_call":
+                # Use raw_text to preserve structural markers for function call detector
+                normal_parts.append(e.raw_text if e.raw_text else e.content)
+        normal_text = "".join(normal_parts)
+
+        return StreamingParseResult(
+            normal_text=normal_text,
+            reasoning_text=reasoning_text,
+        )
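For readers without the collapsed harmony_parser.py diff: the rewritten detector only assumes a small event interface from HarmonyParser — parse(chunk) returns events carrying event_type ("reasoning", "normal", or "tool_call"), content, and raw_text, and parse("") flushes whatever is buffered. The following is a toy stand-in honoring that contract for the simplified "analysis...assistantfinal" layout, not the real implementation:

# Toy stand-in for the HarmonyParser event contract assumed by GptOssDetector.
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class Event:
    event_type: str  # "reasoning", "normal", or "tool_call"
    content: str
    raw_text: Optional[str] = None

class ToyHarmonyParser:
    """Recognizes only the simplified 'analysis...assistantfinal...' layout."""

    def __init__(self) -> None:
        self._buffer = ""

    def parse(self, chunk: str) -> List[Event]:
        self._buffer += chunk
        if chunk != "":
            return []  # toy behavior: buffer until flushed with parse("")
        text, self._buffer = self._buffer, ""
        if text.startswith("analysis") and "assistantfinal" in text:
            reasoning, final = text[len("analysis"):].split("assistantfinal", 1)
            return [
                Event("reasoning", reasoning.strip()),
                Event("normal", final.strip()),
            ]
        return [Event("normal", text)] if text else []

parser = ToyHarmonyParser()
events = parser.parse("analysis 2+2=4 assistantfinal The answer is 4.")
events += parser.parse("")  # flush, as detect_and_parse does
assert [e.event_type for e in events] == ["reasoning", "normal"]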
 class ReasoningParser:
@@ -526,7 +276,7 @@ class ReasoningParser:
         self,
         model_type: Optional[str] = None,
         stream_reasoning: bool = True,
-        force_reasoning: bool = False,
+        force_reasoning: Optional[bool] = None,
     ):
         if not model_type:
             raise ValueError("Model type must be specified")
@@ -535,19 +285,25 @@ class ReasoningParser:
         if not detector_class:
             raise ValueError(f"Unsupported model type: {model_type}")
 
-        if model_type.lower() == "qwen3-thinking":
+        # Special cases where we override force_reasoning
+        if model_type.lower() in {"qwen3-thinking", "gpt-oss"}:
             force_reasoning = True
 
-        self.detector = detector_class(
-            stream_reasoning=stream_reasoning, force_reasoning=force_reasoning
-        )
+        # Only pass force_reasoning if explicitly set, let detectors use their defaults
+        kwargs = {"stream_reasoning": stream_reasoning}
+        if force_reasoning is not None:
+            kwargs["force_reasoning"] = force_reasoning
+
+        self.detector = detector_class(**kwargs)
 
-    def parse_non_stream(self, full_text: str) -> Tuple[str, str]:
+    def parse_non_stream(self, full_text: str) -> Tuple[Optional[str], Optional[str]]:
         """Non-streaming call: one-time parsing"""
         ret = self.detector.detect_and_parse(full_text)
         return ret.reasoning_text, ret.normal_text
 
-    def parse_stream_chunk(self, chunk_text: str) -> Tuple[str, str]:
+    def parse_stream_chunk(
+        self, chunk_text: str
+    ) -> Tuple[Optional[str], Optional[str]]:
         """Streaming call: incremental parsing"""
         ret = self.detector.parse_streaming_increment(chunk_text)
         return ret.reasoning_text, ret.normal_text
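End-to-end, the new plumbing can be exercised roughly like this (a usage sketch assuming an SGLang install where sglang.srt.reasoning_parser is importable; the exact whitespace of the returned strings depends on HarmonyParser):

from sglang.srt.reasoning_parser import ReasoningParser

completion = (
    "<|channel|>analysis<|message|>User asks 2+2. Trivial.<|end|>"
    "<|start|>assistant<|channel|>final<|message|>4<|return|>"
)

parser = ReasoningParser(model_type="gpt-oss")  # force_reasoning is switched on internally
reasoning_text, normal_text = parser.parse_non_stream(completion)
# Expected shape: reasoning_text holds the analysis channel, normal_text holds the
# final channel; tool-call channels would be kept in normal_text with their raw
# structural markers for the downstream function-call detector.

# Streaming: use a fresh parser and feed chunks as they arrive.
stream_parser = ReasoningParser(model_type="gpt-oss")
for chunk in ("<|channel|>analysis<|message|>Think", "ing...<|end|>",
              "<|start|>assistant<|channel|>final<|message|>Done<|return|>"):
    r, n = stream_parser.parse_stream_chunk(chunk)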
@@ -2271,6 +2271,7 @@ class ServerArgs:
         if is_mxfp4_quant_format:
             # use bf16 for mxfp4 triton kernels
             self.dtype = "bfloat16"
+
         elif "Llama4" in model_arch:
             assert self.attention_backend in {
                 "fa3",
@@ -73,6 +73,7 @@ suites = {
         TestFile("test_function_call_parser.py", 10),
         TestFile("test_fused_moe.py", 30),
         TestFile("test_gpt_oss_1gpu.py", 600),
+        TestFile("test_harmony_parser.py", 20),
         TestFile("test_hidden_states.py", 55),
         TestFile("test_hybrid_attn_backend.py", 100),
         TestFile("test_input_embeddings.py", 38),
This diff is collapsed.