Unverified Commit a2184901 authored by Chang Su's avatar Chang Su Committed by GitHub
Browse files

(gpt-oss, oai, chat): Remove Harmony Integration and Implement Native GPT-OSS...

(gpt-oss, oai, chat): Remove Harmony Integration and Implement Native GPT-OSS Tool Call Support (#9043)
parent 0eec4cb6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from vLLM: https://github.com/vllm-project/vllm/blob/1b9902806915040ac9b3029f2ab7522ec505afc3/vllm/entrypoints/harmony_utils.py
# Slight differences in processing chat messages
import datetime
import json
from collections.abc import Iterable
......
......@@ -174,7 +174,6 @@ async def lifespan(fast_api_app: FastAPI):
tool_server=tool_server,
)
except Exception as e:
# print stack trace
import traceback
traceback.print_exc()
......
......@@ -859,15 +859,6 @@ class ResponseReasoningTextContent(BaseModel):
type: Literal["reasoning_text"] = "reasoning_text"
class ResponseReasoningItem(BaseModel):
id: str
content: list[ResponseReasoningTextContent] = Field(default_factory=list)
summary: list = Field(default_factory=list)
type: Literal["reasoning"] = "reasoning"
encrypted_content: Optional[str] = None
status: Optional[Literal["in_progress", "completed", "incomplete"]]
ResponseInputOutputItem: TypeAlias = Union[
ResponseInputItemParam, "ResponseReasoningItem", ResponseFunctionToolCall
]
......@@ -7,18 +7,8 @@ from typing import Any, AsyncGenerator, Dict, List, Optional, Union
from fastapi import Request
from fastapi.responses import ORJSONResponse, StreamingResponse
from openai_harmony import Message as OpenAIMessage
from sglang.srt.conversation import generate_chat_conv
from sglang.srt.entrypoints.harmony_utils import (
get_developer_message,
get_stop_tokens_for_assistant_actions,
get_streamable_parser_for_assistant,
get_system_message,
parse_chat_input,
parse_output_into_messages,
render_for_completion,
)
from sglang.srt.entrypoints.openai.protocol import (
ChatCompletionRequest,
ChatCompletionResponse,
......@@ -57,30 +47,12 @@ class OpenAIServingChat(OpenAIServingBase):
"""Handler for /v1/chat/completions requests"""
def __init__(
self, tokenizer_manager: TokenizerManager, template_manager: TemplateManager
self,
tokenizer_manager: TokenizerManager,
template_manager: TemplateManager,
):
super().__init__(tokenizer_manager)
self.template_manager = template_manager
self.use_harmony = (
self.tokenizer_manager.model_config.hf_config.model_type == "gpt_oss"
)
if self.use_harmony:
from sglang.srt.function_call.harmony_tool_parser import (
HarmonyToolCallParser,
)
self.harmony_tool_parser = HarmonyToolCallParser()
# NOTE While OpenAI's chat completion API supports browsing
# for some models, currently vLLM doesn't support it. Please use the
# Responses API instead.
self.supports_browsing = False
self.browser_tool = None
# NOTE: Chat completion API does not support code interpreter.
# Please use the Responses API instead.
self.supports_code_interpreter = False
self.python_tool = None
def _request_id_prefix(self) -> str:
return "chatcmpl-"
......@@ -97,6 +69,18 @@ class OpenAIServingChat(OpenAIServingBase):
):
return "Tools cannot be empty if tool choice is set to required."
max_output_tokens = request.max_completion_tokens or request.max_tokens
server_context_length = self.tokenizer_manager.server_args.context_length
if (
max_output_tokens
and server_context_length
and max_output_tokens > server_context_length
):
return (
f"max_completion_tokens is too large: {max_output_tokens}."
f"This model supports at most {server_context_length} completion tokens."
)
return None
def _convert_to_internal_request(
......@@ -107,66 +91,43 @@ class OpenAIServingChat(OpenAIServingBase):
is_multimodal = self.tokenizer_manager.model_config.is_multimodal
# Process messages and apply chat template
if not self.use_harmony:
processed_messages = self._process_messages(request, is_multimodal)
# Build sampling parameters
sampling_params = self._build_sampling_params(
request,
processed_messages.stop,
processed_messages.tool_call_constraint,
)
processed_messages = self._process_messages(request, is_multimodal)
# Handle single vs multiple requests
if is_multimodal:
prompt_kwargs = {"text": processed_messages.prompt}
else:
if isinstance(processed_messages.prompt_ids, str):
prompt_kwargs = {"text": processed_messages.prompt_ids}
else:
prompt_kwargs = {"input_ids": processed_messages.prompt_ids}
adapted_request = GenerateReqInput(
**prompt_kwargs,
image_data=processed_messages.image_data,
video_data=processed_messages.video_data,
audio_data=processed_messages.audio_data,
sampling_params=sampling_params,
return_logprob=request.logprobs,
logprob_start_len=-1,
top_logprobs_num=request.top_logprobs or 0,
stream=request.stream,
return_text_in_logprobs=True,
modalities=processed_messages.modalities,
lora_path=request.lora_path,
bootstrap_host=request.bootstrap_host,
bootstrap_port=request.bootstrap_port,
bootstrap_room=request.bootstrap_room,
return_hidden_states=request.return_hidden_states,
rid=request.rid,
)
# Build sampling parameters
sampling_params = self._build_sampling_params(
request,
processed_messages.stop,
processed_messages.tool_call_constraint,
)
# Handle single vs multiple requests
if is_multimodal:
prompt_kwargs = {"text": processed_messages.prompt}
else:
processed_messages, prompt_ids = self._make_request_with_harmony(request)
adapted_request = GenerateReqInput(
input_ids=prompt_ids,
sampling_params=self._build_sampling_params(
request,
request.stop,
tool_call_constraint=None,
),
stream=request.stream,
return_logprob=request.logprobs,
logprob_start_len=-1,
top_logprobs_num=request.top_logprobs or 0,
return_text_in_logprobs=True,
lora_path=request.lora_path,
bootstrap_host=request.bootstrap_host,
bootstrap_port=request.bootstrap_port,
bootstrap_room=request.bootstrap_room,
return_hidden_states=request.return_hidden_states,
rid=request.rid,
)
if isinstance(processed_messages.prompt_ids, str):
prompt_kwargs = {"text": processed_messages.prompt_ids}
else:
prompt_kwargs = {"input_ids": processed_messages.prompt_ids}
adapted_request = GenerateReqInput(
**prompt_kwargs,
image_data=processed_messages.image_data,
video_data=processed_messages.video_data,
audio_data=processed_messages.audio_data,
sampling_params=sampling_params,
return_logprob=request.logprobs,
logprob_start_len=-1,
top_logprobs_num=request.top_logprobs or 0,
stream=request.stream,
return_text_in_logprobs=True,
modalities=processed_messages.modalities,
lora_path=request.lora_path,
bootstrap_host=request.bootstrap_host,
bootstrap_port=request.bootstrap_port,
bootstrap_room=request.bootstrap_room,
return_hidden_states=request.return_hidden_states,
rid=request.rid,
)
return adapted_request, request
......@@ -251,14 +212,16 @@ class OpenAIServingChat(OpenAIServingBase):
tokenize=True,
add_generation_prompt=True,
tools=tools,
reasoning_effort=request.reasoning_effort,
builtin_tools=[],
**(
request.chat_template_kwargs if request.chat_template_kwargs else {}
),
)
except Exception:
# This except branch will be triggered when the chosen model
# has a different tools input format that is not compatible
# with openAI's apply_chat_template tool_call format, like Mistral.
# This except branch will be triggered when the chosen model
# has a different tools input format that is not compatible
# with openAI's apply_chat_template tool_call format, like Mistral.
tools = (
[t if "function" in t else {"function": t} for t in tools]
if tools
......@@ -269,6 +232,8 @@ class OpenAIServingChat(OpenAIServingBase):
tokenize=True,
add_generation_prompt=True,
tools=tools,
reasoning_effort=request.reasoning_effort,
builtin_tools=[],
**(
request.chat_template_kwargs if request.chat_template_kwargs else {}
),
......@@ -459,12 +424,6 @@ class OpenAIServingChat(OpenAIServingBase):
cached_tokens = {}
hidden_states = {}
# Harmony tracking
if self.use_harmony:
harmony_parsers = [
get_streamable_parser_for_assistant() for _ in range(request.n)
]
try:
async for content in self.tokenizer_manager.generate_request(
adapted_request, raw_request
......@@ -511,58 +470,14 @@ class OpenAIServingChat(OpenAIServingBase):
)
yield f"data: {chunk.model_dump_json()}\n\n"
# Process content delta
if self.use_harmony:
harmony_parser = harmony_parsers[index]
new_token_ids = content["output_ids"]
for token_id in new_token_ids:
harmony_parser.process(token_id)
is_final = harmony_parser.current_channel == "final"
is_analysis = harmony_parser.current_channel == "analysis"
delta = harmony_parser.last_content_delta or ""
if is_analysis:
choice_data = ChatCompletionResponseStreamChoice(
index=index,
delta=DeltaMessage(reasoning_content=delta),
finish_reason=None,
)
chunk = ChatCompletionStreamResponse(
id=content["meta_info"]["id"],
created=int(time.time()),
choices=[choice_data],
model=request.model,
)
yield f"data: {chunk.model_dump_json()}\n\n"
continue
choice_data = ChatCompletionResponseStreamChoice(
index=index,
delta=DeltaMessage(content=delta if delta else None),
finish_reason=None,
matched_stop=None,
logprobs=choice_logprobs,
)
chunk = ChatCompletionStreamResponse(
id=content["meta_info"]["id"],
created=int(time.time()),
choices=[choice_data],
model=request.model,
)
yield f"data: {chunk.model_dump_json()}\n\n"
continue
else:
stream_buffer = stream_buffers.get(index, "")
delta = content["text"][len(stream_buffer) :]
stream_buffers[index] = stream_buffer + delta
stream_buffer = stream_buffers.get(index, "")
delta = content["text"][len(stream_buffer) :]
stream_buffers[index] = stream_buffer + delta
# Handle reasoning content
if (
self.tokenizer_manager.server_args.reasoning_parser
and request.separate_reasoning
and not self.use_harmony
):
reasoning_text, delta = self._process_reasoning_stream(
index, delta, reasoning_parser_dict, content, request
......@@ -581,27 +496,8 @@ class OpenAIServingChat(OpenAIServingBase):
)
yield f"data: {chunk.model_dump_json()}\n\n"
if self.use_harmony and not is_final:
choice_data = ChatCompletionResponseStreamChoice(
index=index,
delta=DeltaMessage(reasoning_content=delta),
finish_reason=None,
)
chunk = ChatCompletionStreamResponse(
id=content["meta_info"]["id"],
created=int(time.time()),
choices=[choice_data],
model=request.model,
)
yield f"data: {chunk.model_dump_json()}\n\n"
# Handle tool calls
# TODO: support tool call parsing for harmony
if (
request.tool_choice != "none"
and request.tools
and not self.use_harmony
):
if request.tool_choice != "none" and request.tools:
async for chunk in self._process_tool_call_stream(
index,
delta,
......@@ -765,76 +661,6 @@ class OpenAIServingChat(OpenAIServingBase):
finish_reason = ret_item["meta_info"]["finish_reason"]
text = ret_item["text"]
output_ids = ret_item["output_ids"]
if self.use_harmony:
parser = parse_output_into_messages(output_ids)
output_msgs = parser.messages
if len(output_msgs) == 0:
# The generation has stopped during reasoning.
is_tool_call = False
reasoning_content = parser.current_content
final_content = None
elif len(output_msgs) == 1:
# The generation has stopped during final message.
is_tool_call = False
reasoning_content = output_msgs[0].content[0].text
final_content = parser.current_content
else:
if len(output_msgs) != 2:
raise ValueError(
"Expected 2 output messages (reasoning and final), "
f"but got {len(output_msgs)}."
)
reasoning_msg, final_msg = output_msgs
reasoning_content = reasoning_msg.content[0].text
final_content = final_msg.content[0].text
is_tool_call = final_msg.recipient is not None
if is_tool_call:
# Extract tool call information from final message
tool_call = (
self.harmony_tool_parser.extract_tool_calls_from_message(
final_msg
)
)
tool_calls = [tool_call] if tool_call else []
message = ChatMessage(
role="assistant",
reasoning_content=reasoning_content,
content=None, # Tool calls don't have regular content
tool_calls=tool_calls,
)
else:
# Normal message
message = ChatMessage(
role="assistant",
reasoning_content=reasoning_content,
content=final_content,
)
if is_tool_call:
finish_reason_type = "tool_calls"
elif finish_reason:
finish_reason_type = (
finish_reason["type"] if finish_reason else "stop"
)
else:
finish_reason_type = "stop"
choice_data = ChatCompletionResponseChoice(
index=idx,
message=message,
logprobs=choice_logprobs,
finish_reason=finish_reason_type,
matched_stop=(
finish_reason["matched"]
if finish_reason and "matched" in finish_reason
else None
),
)
choices.append(choice_data)
continue
# Handle reasoning content
reasoning_text = None
......@@ -1184,33 +1010,3 @@ class OpenAIServingChat(OpenAIServingBase):
return f"data: {chunk.model_dump_json()}\n\n"
return None
def _make_request_with_harmony(
self,
request: ChatCompletionRequest,
):
messages: list[OpenAIMessage] = []
# Add system message.
# In Chat Completion API, browsing is enabled by default if the model
# supports it.
assert not self.supports_browsing
assert not self.supports_code_interpreter
sys_msg = get_system_message(
reasoning_effort=request.reasoning_effort,
browser_description=None,
python_description=None,
)
messages.append(sys_msg)
# Add developer message.
dev_msg = get_developer_message()
messages.append(dev_msg)
# Add user message.
for chat_msg in request.messages:
messages.append(parse_chat_input(chat_msg))
# Render prompt token ids.
prompt_token_ids = render_for_completion(messages)
return messages, prompt_token_ids
......@@ -11,6 +11,7 @@ from sglang.srt.function_call.base_format_detector import BaseFormatDetector
from sglang.srt.function_call.core_types import ToolCallItem
from sglang.srt.function_call.deepseekv3_detector import DeepSeekV3Detector
from sglang.srt.function_call.glm4_moe_detector import Glm4MoeDetector
from sglang.srt.function_call.gpt_oss_detector import GptOssDetector
from sglang.srt.function_call.kimik2_detector import KimiK2Detector
from sglang.srt.function_call.llama32_detector import Llama32Detector
from sglang.srt.function_call.mistral_detector import MistralDetector
......@@ -41,6 +42,7 @@ class FunctionCallParser:
"qwen3_coder": Qwen3CoderDetector,
"glm45": Glm4MoeDetector,
"step3": Step3Detector,
"gpt-oss": GptOssDetector,
}
def __init__(self, tools: List[Tool], tool_call_parser: str):
......
import json
import logging
import re
from typing import List
from sglang.srt.entrypoints.openai.protocol import Tool
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
from sglang.srt.function_call.core_types import (
StreamingParseResult,
ToolCallItem,
_GetInfoFunc,
)
logger = logging.getLogger(__name__)
class GptOssDetector(BaseFormatDetector):
"""
Detector for T4-style function calls with channel format.
Supports two formats:
1. Direct function call: <|channel|>commentary to={namespace.function}<|constrain|>json<|message|>{args}<|call|>
2. Commentary with action plan: <|channel|>commentary<|message|>{content}<|end|>
For parallel function calls, each call is self-contained and starts with its own channel:
<|channel|>commentary to=functions.get_weather<|constrain|>json<|message|>{"location":"SF"}<|call|>
<|channel|>commentary to=functions.search<|constrain|>json<|message|>{"query":"SF attractions"}<|call|>
Examples:
Single: <|channel|>commentary to=functions.get_weather<|constrain|>json<|message|>{"location":"San Francisco"}<|call|>commentary
Multiple: <|channel|>commentary to=functions.get_weather<|constrain|>json<|message|>{"location":"Paris"}<|call|>commentary<|channel|>commentary to=functions.search<|constrain|>json<|message|>{"query":"Paris tourism"}<|call|>
With Action Plan: <|channel|>commentary<|message|>**Action plan**: 1. Do X 2. Do Y<|end|><|start|>assistant<|channel|>commentary to=functions.x<|constrain|>json<|message|>{"template": "basic_html", "path": "index.html"}<|call|>
"""
def __init__(self):
super().__init__()
self.bot_token = "<|start|>assistant<|channel|>commentary"
self.eot_token = "<|call|>"
# TODO: no clear indication how parallel tool call response format is
self.tool_call_separator = ""
# Pattern for complete function calls with to= parameter
# Handles both <|call|> and <|call|>commentary endings
# Also handles optional <|start|>assistant prefix and whitespace after function name
self.function_call_pattern = re.compile(
r"(?:<\|start\|>assistant)?<\|channel\|>commentary to=([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*)\s*"
r"<\|constrain\|>json<\|message\|>(.*?)<\|call\|>(?:commentary)?",
re.DOTALL,
)
# Pattern for streaming function calls (incomplete)
# Also handles optional whitespace after function name
self.streaming_pattern = re.compile(
r"(?:<\|start\|>assistant)?<\|channel\|>commentary to=([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*)\s*"
r"<\|constrain\|>json<\|message\|>(.*)",
re.DOTALL,
)
# Pattern for commentary with action plan (no to= parameter)
self.commentary_pattern = re.compile(
r"<\|channel\|>commentary<\|message\|>(.*?)<\|end\|>",
re.DOTALL,
)
self._last_arguments = ""
def has_tool_call(self, text: str) -> bool:
"""Check if text contains TypeScript-style function call markers."""
return self.bot_token in text
def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
"""Parse TypeScript-style function calls from complete text."""
if not self.has_tool_call(text):
return StreamingParseResult(normal_text=text, calls=[])
tool_indices = self._get_tool_indices(tools)
calls = []
tool_index = 0
# Process the entire text to handle mixed commentary and tool calls
normal_text_parts = []
# Find all commentary sections (both with and without to=)
all_commentary_pattern = re.compile(
r"<\|channel\|>commentary(?:\s+to=[^<]*)?<\|message\|>(.*?)(?:<\|end\|>|<\|call\|>)",
re.DOTALL,
)
# Track processed positions to avoid double-processing
processed_ranges = []
# First, extract all tool calls
for match in self.function_call_pattern.finditer(text):
full_function_name = match.group(1)
args_content = match.group(2)
processed_ranges.append((match.start(), match.end()))
function_name = (
full_function_name.split(".")[-1]
if "." in full_function_name
else full_function_name
)
try:
arguments = json.loads(args_content) if args_content.strip() else {}
except json.JSONDecodeError:
continue
if function_name in tool_indices:
calls.append(
ToolCallItem(
tool_index=tool_index,
name=function_name,
parameters=json.dumps(arguments, ensure_ascii=False),
)
)
tool_index += 1
# Then, find non-tool-call commentary sections for normal text
for match in all_commentary_pattern.finditer(text):
# Check if this match overlaps with any processed tool call
match_start, match_end = match.start(), match.end()
is_tool_call = any(
start <= match_start < end or start < match_end <= end
for start, end in processed_ranges
)
# If this commentary is not part of a tool call, include it in normal text
if not is_tool_call:
content = match.group(1).strip()
if content:
normal_text_parts.append(content)
# Handle remaining text after all matches
if processed_ranges:
last_match_end = max(end for _, end in processed_ranges)
if last_match_end < len(text):
remaining_text = text[last_match_end:]
# Clean up <|start|>assistant prefixes and extract final content
# Remove standalone <|start|>assistant prefixes
remaining_text = re.sub(r"<\|start\|>assistant(?!\w)", "", remaining_text)
# Extract content from final channel if present
final_pattern = re.compile(
r"<\|channel\|>final<\|message\|>(.*?)(?:<\|return\|>|$)", re.DOTALL
)
final_match = final_pattern.search(remaining_text)
if final_match:
# Get everything before final channel + final channel content
before_final = remaining_text[: final_match.start()].strip()
final_content = final_match.group(1).strip()
parts = []
if before_final:
parts.append(before_final)
if final_content:
parts.append(final_content)
remaining_text = " ".join(parts) if parts else ""
remaining_text = remaining_text.strip()
if remaining_text:
normal_text_parts.append(remaining_text)
# Combine all normal text parts
final_normal_text = " ".join(part for part in normal_text_parts if part).strip()
return StreamingParseResult(normal_text=final_normal_text, calls=calls)
def parse_streaming_increment(
self, new_text: str, tools: List[Tool]
) -> StreamingParseResult:
"""Parse incremental streaming text for TypeScript-style function calls."""
self._buffer += new_text
current_text = self._buffer
# Check if we have a tool call
has_tool_call = "<|channel|>commentary to=" in current_text
if not has_tool_call and current_text:
# Check for commentary without function calls
commentary_match = self.commentary_pattern.search(current_text)
if commentary_match:
commentary_content = commentary_match.group(1)
self._buffer = current_text[commentary_match.end() :]
return StreamingParseResult(normal_text=commentary_content, calls=[])
# Check for final channel content
final_pattern = re.compile(
r"<\|channel\|>final<\|message\|>(.*?)(?:<\|return\|>|$)",
re.DOTALL,
)
final_match = final_pattern.search(current_text)
if final_match:
final_content = final_match.group(1).strip()
self._buffer = ""
return StreamingParseResult(normal_text=final_content, calls=[])
self._buffer = ""
return StreamingParseResult(normal_text=new_text, calls=[])
if not hasattr(self, "_tool_indices"):
self._tool_indices = self._get_tool_indices(tools)
calls = []
try:
# Check for streaming function call
match = self.streaming_pattern.search(current_text)
if match:
full_function_name = match.group(1)
args_content = match.group(2)
function_name = (
full_function_name.split(".")[-1]
if "." in full_function_name
else full_function_name
)
# Initialize state if this is the first tool call
if self.current_tool_id == -1:
self.current_tool_id = 0
self.prev_tool_call_arr = []
self.streamed_args_for_tool = [""]
# Ensure we have enough entries in tracking arrays
while len(self.prev_tool_call_arr) <= self.current_tool_id:
self.prev_tool_call_arr.append({})
while len(self.streamed_args_for_tool) <= self.current_tool_id:
self.streamed_args_for_tool.append("")
if not self.current_tool_name_sent:
calls.append(
ToolCallItem(
tool_index=self.current_tool_id,
name=function_name,
parameters="",
)
)
self.current_tool_name_sent = True
# Store the tool call info
self.prev_tool_call_arr[self.current_tool_id] = {
"name": function_name,
"arguments": {},
}
self.streamed_args_for_tool[self.current_tool_id] = ""
# Check if we have a complete function call
complete_match = self.function_call_pattern.search(current_text)
if complete_match:
args_content = complete_match.group(2)
try:
parsed_args = json.loads(args_content)
self.prev_tool_call_arr[self.current_tool_id][
"arguments"
] = parsed_args
# Send complete arguments if we haven't sent them yet
if not self.streamed_args_for_tool[self.current_tool_id]:
# Send the complete arguments as JSON string
calls.append(
ToolCallItem(
tool_index=self.current_tool_id,
name=None,
parameters=json.dumps(
parsed_args, ensure_ascii=False
),
)
)
self.streamed_args_for_tool[self.current_tool_id] = (
json.dumps(parsed_args, ensure_ascii=False)
)
except json.JSONDecodeError:
pass
# Remove the completed function call from buffer
remaining_after_call = current_text[complete_match.end() :]
# Clean up <|start|>assistant prefixes and extract final content
remaining_after_call = re.sub(
r"<\|start\|>assistant(?!\w)", "", remaining_after_call
)
# Extract content from final channel if present
final_pattern = re.compile(
r"<\|channel\|>final<\|message\|>(.*?)(?:<\|return\|>|$)",
re.DOTALL,
)
final_match = final_pattern.search(remaining_after_call)
if final_match:
before_final = remaining_after_call[
: final_match.start()
].strip()
final_content = final_match.group(1).strip()
parts = []
if before_final:
parts.append(before_final)
if final_content:
parts.append(final_content)
remaining_after_call = " ".join(parts) if parts else ""
self._buffer = remaining_after_call.strip()
# Reset state for next tool call
self.current_tool_name_sent = False
self.current_tool_id += 1
# Return final content if available
final_text = ""
if final_match and final_content:
final_text = final_content
elif remaining_after_call:
final_text = remaining_after_call
return StreamingParseResult(normal_text=final_text, calls=calls)
return StreamingParseResult(normal_text="", calls=calls)
except Exception as e:
logger.error(f"Error in parse_streaming_increment: {e}")
return StreamingParseResult(normal_text=current_text, calls=[])
def structure_info(self) -> _GetInfoFunc:
raise NotImplementedError()
def build_ebnf(self, tools: List[Tool]) -> str:
raise NotImplementedError()
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Harmony tool call parser for processing tool calls in harmony models."""
import uuid
from typing import List, Optional, Tuple
from sglang.srt.entrypoints.openai.protocol import (
ChatMessage,
FunctionResponse,
ToolCall,
)
class HarmonyToolCallParser:
"""Parser for extracting tool calls from harmony model outputs."""
def extract_tool_calls_from_message(self, msg) -> Optional[ToolCall]:
"""
Extract tool call from a single message if it's a tool call.
Args:
msg: The harmony message
Returns:
ToolCall if the message is a tool call, None otherwise
"""
if (
msg.channel == "commentary"
and msg.recipient
and msg.recipient.startswith("functions.")
):
function_name = msg.recipient.split(".")[-1]
arguments = msg.content[0].text if msg.content else "{}"
return ToolCall(
id=f"call_{uuid.uuid4().hex[:24]}",
function=FunctionResponse(
name=function_name,
arguments=arguments,
),
)
return None
def process_streaming_chunk(
self,
harmony_parser,
index: int,
tool_call_trackers: dict,
stream_buffers: dict,
) -> Tuple[Optional[dict], bool, Optional[str]]:
"""
Process a streaming chunk for tool calls.
Args:
harmony_parser: The harmony parser instance
index: The choice index
tool_call_trackers: Dict tracking tool calls per choice
stream_buffers: Dict for buffering content
Returns:
Tuple of (tool_call_data, is_tool_call, delta)
"""
# Check if we're in a tool call
is_tool_call = (
harmony_parser.current_channel == "commentary"
and harmony_parser.current_recipient
and harmony_parser.current_recipient.startswith("functions.")
)
delta = harmony_parser.last_content_delta or ""
tool_call_data = None
if is_tool_call:
# Handle tool call streaming
function_name = harmony_parser.current_recipient.split(".")[-1]
# Track tool call indices per choice
if index not in tool_call_trackers:
tool_call_trackers[index] = {"count": 0, "current_function": None}
# Check if we just started a new tool call
tool_call_tracker = tool_call_trackers[index]
if tool_call_tracker["current_function"] != function_name:
# New tool call started
tool_call_tracker["current_function"] = function_name
tool_call_index = tool_call_tracker["count"]
tool_call_tracker["count"] += 1
# Store the tool call index for this function
tool_call_key = f"{index}_{function_name}"
stream_buffers[tool_call_key] = {
"index": tool_call_index,
"content": "",
}
tool_call_data = {
"id": f"call_{uuid.uuid4().hex[:24]}",
"index": tool_call_index,
"function_name": function_name,
"arguments": delta,
"is_first_chunk": True,
}
else:
# Subsequent chunks for the same tool call
tool_call_key = f"{index}_{function_name}"
tool_call_index = stream_buffers[tool_call_key]["index"]
tool_call_data = {
"id": None,
"index": tool_call_index,
"function_name": None,
"arguments": delta,
"is_first_chunk": False,
}
stream_buffers[tool_call_key]["content"] += delta
return tool_call_data, is_tool_call, delta
import re
from typing import Dict, Optional, Tuple, Type
......@@ -185,6 +186,320 @@ class KimiDetector(BaseReasoningFormatDetector):
)
class GptOssDetector(BaseReasoningFormatDetector):
"""
Detector for T4-style reasoning format.
Assumes reasoning format with two channels:
<|channel|>analysis<|message|>...reasoning content...<|end|>
<|start|>assistant<|channel|>final<|message|>...final answer...<|return|>
Returns content from 'analysis' channel as reasoning_text
and content from 'final' channel as normal_text.
Args:
stream_reasoning (bool): If False, accumulates reasoning content until complete.
If True, streams reasoning content as it arrives.
"""
def __init__(self, stream_reasoning: bool = True, force_reasoning: bool = True):
# TypeScript uses channel tokens instead of simple start/end tokens
super().__init__(
"<|channel|>analysis<|message|>",
"<|end|>",
force_reasoning=True,
stream_reasoning=stream_reasoning,
)
self.final_channel_start = "<|start|>assistant<|channel|>final<|message|>"
self.final_channel_end = "<|return|>"
self._in_final_channel = False
self._analysis_complete = False
self._in_reasoning = True
def detect_and_parse(self, text: str) -> StreamingParseResult:
"""
One-time parsing: Detects and parses both analysis and final channels.
Tool call channels are preserved in normal_text for downstream processing.
HACK: Also handles simplified format where text starts with "analysis" and transitions
to "assistantfinal" without full channel markers.
"""
# HACK: Handle simplified format (analysis...assistantfinal) without channel markers
if (
text.startswith("analysis")
and "assistantfinal" in text
and "<|channel|>" not in text
):
# Split on "assistantfinal"
parts = text.split("assistantfinal", 1)
self._in_reasoning = False
if len(parts) == 2:
reasoning_text = parts[0][
len("analysis") :
].strip() # Remove "analysis" prefix
normal_text = parts[1].strip()
return StreamingParseResult(
normal_text=normal_text, reasoning_text=reasoning_text
)
reasoning_parts = []
normal_parts = []
current_pos = 0
# Process text sequentially to preserve tool calls between analysis sections
while current_pos < len(text):
# Look for next analysis channel
analysis_start_idx = text.find(self.think_start_token, current_pos)
if analysis_start_idx == -1:
# No more analysis channels, rest goes to remaining
break
# Preserve any content before this analysis channel (could include tool calls)
if analysis_start_idx > current_pos:
between_content = text[current_pos:analysis_start_idx]
# This content will be added to normal_parts later
normal_parts.append(between_content)
# Extract analysis content
analysis_content_start = analysis_start_idx + len(self.think_start_token)
analysis_end_idx = text.find(self.think_end_token, analysis_content_start)
if analysis_end_idx != -1:
reasoning_parts.append(
text[analysis_content_start:analysis_end_idx].strip()
)
current_pos = analysis_end_idx + len(self.think_end_token)
else:
# Analysis not complete
reasoning_parts.append(text[analysis_content_start:].strip())
reasoning_text = "".join(reasoning_parts)
return StreamingParseResult(reasoning_text=reasoning_text)
# Add any remaining text after all analysis sections
if current_pos < len(text):
remaining = text[current_pos:]
normal_parts.append(remaining)
# Process non-analysis content for commentary sections
full_normal_text = "".join(normal_parts)
# Extract reasoning from non-tool-call commentary sections
# Tool calls have "to=" in their header, regular commentary does not
commentary_pattern = re.compile(
r"<\|start\|>assistant<\|channel\|>commentary<\|message\|>(.*?)(?:<\|end\|>|<\|call\|>)",
re.DOTALL,
)
cleaned_text = full_normal_text
for match in reversed(list(commentary_pattern.finditer(full_normal_text))):
# Check if this commentary is a tool call by looking at the text before <|message|>
match_start = match.start()
# Find the start of this commentary section
commentary_start = full_normal_text.rfind(
"<|channel|>commentary", 0, match_start
)
if commentary_start != -1:
# Extract text between "commentary" and "<|message|>"
message_pos = full_normal_text.find("<|message|>", commentary_start)
if message_pos != -1:
between_text = full_normal_text[commentary_start:message_pos]
# If no "to=" found, this is regular commentary (reasoning content)
if " to=" not in between_text:
content = match.group(1).strip()
reasoning_parts.append(content)
# Remove this commentary section from normal text
cleaned_text = (
cleaned_text[: match.start()] + cleaned_text[match.end() :]
)
full_normal_text = cleaned_text
# Combine all reasoning parts
reasoning_text = "".join(reasoning_parts)
# Process full_normal_text for final output
normal_text = ""
if self.final_channel_start in full_normal_text:
final_start = full_normal_text.find(self.final_channel_start)
final_content_start = final_start + len(self.final_channel_start)
final_end = full_normal_text.find(
self.final_channel_end, final_content_start
)
if final_end != -1:
# Extract content before final channel (includes tool calls)
before_final = full_normal_text[:final_start].strip()
# Extract ONLY the final channel content (not the channel markers)
final_text = full_normal_text[final_content_start:final_end].strip()
# Extract content after final channel
after_final = full_normal_text[
final_end + len(self.final_channel_end) :
].strip()
# For tool calls + final answer: concatenate tool calls with final text
parts = []
if before_final:
parts.append(before_final)
if final_text:
parts.append(final_text)
if after_final:
parts.append(after_final)
normal_text = " ".join(parts)
else:
# Final channel not complete - extract what we have
# Look for just <|channel|>final<|message|> without <|return|>
alt_final_start = full_normal_text.find("<|channel|>final<|message|>")
if alt_final_start != -1:
before_alt_final = full_normal_text[:alt_final_start].strip()
alt_final_content = full_normal_text[
alt_final_start + len("<|channel|>final<|message|>") :
].strip()
parts = []
if before_alt_final:
parts.append(before_alt_final)
if alt_final_content:
parts.append(alt_final_content)
normal_text = " ".join(parts)
else:
normal_text = full_normal_text.strip()
else:
# No final channel, treat all as normal text (includes tool calls)
normal_text = full_normal_text.strip()
return StreamingParseResult(
normal_text=normal_text, reasoning_text=reasoning_text
)
def parse_streaming_increment(self, new_text: str) -> StreamingParseResult:
"""
Streaming incremental parsing for GPT-OSS format.
This is a simplified streaming implementation that accumulates content
and delegates to the non-streaming parser for complex multi-channel parsing.
TODO: Implement proper incremental parsing for better streaming performance.
"""
self._buffer += new_text
if not self._in_reasoning:
return StreamingParseResult(normal_text=new_text)
# Check if we have complete sections to process
# For GPT-OSS, we need to wait for complete channel sections
# HACK: For now, use simplified approach - wait for key markers before processing
key_markers = ["<|end|>", "<|call|>", "<|return|>", "assistantfinal"]
has_complete_section = any(marker in self._buffer for marker in key_markers)
if not has_complete_section:
# Still accumulating, don't process yet
return StreamingParseResult()
# Handle simplified format (analysis...assistantfinal) with true incremental streaming
if (
"<|channel|>" not in self._buffer
): # Simplified format without channel markers
if self._buffer.startswith("analysis"):
# Check if we have the transition to assistantfinal
if "assistantfinal" in self._buffer:
self._in_reasoning = False
# Complete reasoning section - extract and stream it
parts = self._buffer.split("assistantfinal", 1)
reasoning_text = parts[0][len("analysis") :].strip()
final_content = parts[1].strip()
# Clear buffer and return both reasoning and final content
self._buffer = ""
return StreamingParseResult(
reasoning_text=reasoning_text if self.stream_reasoning else "",
normal_text=final_content,
)
elif self.stream_reasoning:
# Stream reasoning content incrementally as it arrives
current_reasoning = self._buffer[len("analysis") :].strip()
self._buffer = ""
return StreamingParseResult(reasoning_text=current_reasoning)
else:
# Wait for assistantfinal
return StreamingParseResult()
elif self._buffer.startswith("assistantfinal"):
# Direct final content without analysis
final_content = self._buffer[len("assistantfinal") :].strip()
self._buffer = ""
return StreamingParseResult(normal_text=final_content)
# For full channel format, process sections as they complete
result = StreamingParseResult()
# Process complete analysis sections
while (
self.think_start_token in self._buffer
and self.think_end_token in self._buffer
):
start_idx = self._buffer.find(self.think_start_token)
start_pos = start_idx + len(self.think_start_token)
end_pos = self._buffer.find(self.think_end_token, start_pos)
if end_pos != -1:
reasoning_content = self._buffer[start_pos:end_pos].strip()
if self.stream_reasoning and reasoning_content:
result.reasoning_text += reasoning_content
# Remove processed analysis section
self._buffer = (
self._buffer[:start_idx]
+ self._buffer[end_pos + len(self.think_end_token) :]
)
else:
break
# Process complete commentary sections
commentary_pattern = re.compile(
r"<\|start\|>assistant<\|channel\|>commentary<\|message\|>(.*?)(?:<\|end\|>|<\|call\|>)",
re.DOTALL,
)
for match in reversed(list(commentary_pattern.finditer(self._buffer))):
# Check if this is a tool call
start_pos = match.start()
commentary_content = match.group(1).strip()
if self.stream_reasoning and commentary_content:
result.reasoning_text += commentary_content
# Remove this commentary section
self._buffer = self._buffer[: match.start()] + self._buffer[match.end() :]
# Clean up any standalone <|start|>assistant
self._buffer = re.sub(
r"<\|start\|>assistant(?=<\|start\|>assistant)", "", self._buffer
)
# Handle final channel completion
if self.final_channel_start in self._buffer:
final_start = self._buffer.find(self.final_channel_start)
final_content_start = final_start + len(self.final_channel_start)
# Check if final channel is complete
final_end = self._buffer.find(self.final_channel_end, final_content_start)
if final_end != -1:
# Complete final channel - process everything
final_result = self.detect_and_parse(self._buffer)
self._buffer = ""
return StreamingParseResult(
normal_text=final_result.normal_text,
reasoning_text=result.reasoning_text + final_result.reasoning_text,
)
else:
# Extract content before final channel (e.g. tool calls)
before_final = self._buffer[:final_start]
if before_final:
# Output tool calls for processing
result.normal_text += before_final
# Keep the final channel part in buffer
self._buffer = self._buffer[final_start:]
return result
class ReasoningParser:
"""
Parser that handles both streaming and non-streaming scenarios for extracting
......@@ -203,6 +518,7 @@ class ReasoningParser:
"glm45": Qwen3Detector,
"kimi": KimiDetector,
"step3": DeepSeekR1Detector,
"gpt-oss": GptOssDetector,
}
def __init__(
......
......@@ -1190,7 +1190,7 @@ class ServerArgs:
parser.add_argument(
"--tool-call-parser",
type=str,
choices=[
choices=[ # TODO: use FunctionCallParser.DetectorMap.keys()
"qwen25",
"mistral",
"llama3",
......@@ -1200,6 +1200,7 @@ class ServerArgs:
"qwen3_coder",
"glm45",
"step3",
"gpt-oss",
],
default=ServerArgs.tool_call_parser,
help="Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', 'llama3', 'deepseekv3', 'pythonic', 'kimi_k2', 'qwen3_coder', 'glm45', and 'step3'.",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment