Unverified commit 92cc32d9 authored by Chang Su, committed by GitHub

Support v1/responses and use harmony in serving_chat (#8837)


Signed-off-by: Xinyuan Tong <justinning0323@outlook.com>
Signed-off-by: Xinyuan Tong <xinyuantong.cs@gmail.com>
Co-authored-by: Xinyuan Tong <justinning0323@outlook.com>
Co-authored-by: Xinyuan Tong <xinyuantong.cs@gmail.com>
parent cbbd685a
@@ -29,6 +29,7 @@ runtime_common = [
"modelscope",
"msgspec",
"ninja",
"openai-harmony==0.0.3",
"orjson",
"outlines==0.1.11",
"packaging",
@@ -96,7 +97,7 @@ srt_cpu = ["sglang[runtime_common]", "einops"]
# https://vllm-ascend.readthedocs.io/en/latest/installation.html
srt_npu = ["sglang[runtime_common]"]
-openai = ["openai>=1.0", "tiktoken"]
+openai = ["openai>=1.99.1", "tiktoken"]
anthropic = ["anthropic>=0.20.0"]
litellm = ["litellm>=1.0.0"]
torch_memory_saver = ["torch_memory_saver>=0.0.8"]
...
# SPDX-License-Identifier: Apache-2.0
# Copied from vLLM
import json
import logging
from abc import ABC, abstractmethod
from typing import Union
logger = logging.getLogger(__name__)
try:
from mcp import ClientSession
except ImportError:
logger.warning("Ignoring mcp import error")
from openai_harmony import Author, Message, Role, StreamState, TextContent
from sglang.srt.entrypoints.harmony_utils import (
get_encoding,
get_streamable_parser_for_assistant,
render_for_completion,
)
from sglang.srt.entrypoints.tool import Tool
class ConversationContext(ABC):
@abstractmethod
def append_output(self, output) -> None:
pass
@abstractmethod
async def call_tool(self) -> list[Message]:
pass
@abstractmethod
def need_builtin_tool_call(self) -> bool:
pass
@abstractmethod
def render_for_completion(self) -> list[int]:
pass
class SimpleContext(ConversationContext):
def __init__(self):
self.last_output = None
def append_output(self, output) -> None:
self.last_output = output
def need_builtin_tool_call(self) -> bool:
return False
async def call_tool(self) -> list[Message]:
raise NotImplementedError("Should not be called.")
def render_for_completion(self) -> list[int]:
raise NotImplementedError("Should not be called.")
class HarmonyContext(ConversationContext):
def __init__(
self,
messages: list,
tool_sessions: dict[str, Union["ClientSession", Tool]],
):
# TODO: Remove the Union[ClientSession, Tool] hack by using MCP for the
# demo tools as well.
self._messages = messages
self.tool_sessions = tool_sessions
self.parser = get_streamable_parser_for_assistant()
self.num_init_messages = len(messages)
# TODO
self.num_prompt_tokens = 0
self.num_cached_tokens = 0
self.num_output_tokens = 0
self.num_reasoning_tokens = 0
def append_output(self, output) -> None:
if isinstance(output, dict) and "output_ids" in output:
output_token_ids = output["output_ids"]
# TODO: remove this hack.
# Very hacky: find the first occurrence of token 200006 and cut from there.
try:
start_index = output_token_ids.index(200006)
output_token_ids = output_token_ids[start_index:]
except ValueError:
pass
for token_id in output_token_ids:
self.parser.process(token_id)
output_msgs = self.parser.messages
meta_info = output["meta_info"]
if isinstance(meta_info, dict):
if "prompt_token_ids" in meta_info:
self.num_prompt_tokens = meta_info["prompt_tokens"]
if "cached_tokens" in meta_info:
self.num_cached_tokens = meta_info["cached_tokens"]
if "completion_tokens" in meta_info:
self.num_output_tokens += meta_info["completion_tokens"]
else:
output_msgs = output
self._messages.extend(output_msgs)
@property
def messages(self) -> list:
return self._messages
def need_builtin_tool_call(self) -> bool:
last_msg = self.messages[-1]
recipient = last_msg.recipient
return recipient is not None and (
recipient.startswith("browser.") or recipient.startswith("python")
)
async def call_tool(self) -> list[Message]:
if not self.messages:
return []
last_msg = self.messages[-1]
recipient = last_msg.recipient
if recipient is not None:
if recipient.startswith("browser."):
return await self.call_search_tool(
self.tool_sessions["browser"], last_msg
)
elif recipient.startswith("python"):
return await self.call_python_tool(
self.tool_sessions["python"], last_msg
)
raise ValueError("No tool call found")
def render_for_completion(self) -> list[int]:
return render_for_completion(self.messages)
async def call_search_tool(
self, tool_session: Union["ClientSession", Tool], last_msg: Message
) -> list[Message]:
if isinstance(tool_session, Tool):
return await tool_session.get_result(self)
tool_name = last_msg.recipient.split(".")[1]
args = json.loads(last_msg.content[0].text)
result = await tool_session.call_tool(tool_name, args)
result_str = result.content[0].text
content = TextContent(text=result_str)
author = Author(role=Role.TOOL, name=last_msg.recipient)
return [Message(author=author, content=[content], recipient=Role.ASSISTANT)]
async def call_python_tool(
self, tool_session: Union["ClientSession", Tool], last_msg: Message
) -> list[Message]:
if isinstance(tool_session, Tool):
return await tool_session.get_result(self)
param = {
"code": last_msg.content[0].text,
}
result = await tool_session.call_tool("python", param)
result_str = result.content[0].text
content = TextContent(text=result_str)
author = Author(role=Role.TOOL, name="python")
return [
Message(
author=author,
content=[content],
channel=last_msg.channel,
recipient=Role.ASSISTANT,
)
]
class StreamingHarmonyContext(HarmonyContext):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.last_output = None
self.parser = get_streamable_parser_for_assistant()
self.encoding = get_encoding()
self.last_tok = None
@property
def messages(self) -> list:
return self.parser.messages
def append_output(self, output) -> None:
if isinstance(output, dict) and "output_ids" in output:
# RequestOutput from SGLang with outputs
output_token_ids = output["output_ids"]
# TODO: remove this hack.
# Very hacky: find the first occurrence of token 200006 and cut from there.
try:
start_index = output_token_ids.index(200006)
output_token_ids = output_token_ids[start_index:]
except ValueError:
pass
for token_id in output_token_ids:
self.parser.process(token_id)
else:
# Handle the case of tool output in direct message format
assert len(output) == 1, "Tool output should be a single message"
msg = output[0]
# Sometimes the recipient is not set for tool messages,
# so we set it to "assistant"
if msg.author.role == Role.TOOL and msg.recipient is None:
msg.recipient = "assistant"
toks = self.encoding.render(msg)
for tok in toks:
self.parser.process(tok)
self.last_tok = toks[-1]
def is_expecting_start(self) -> bool:
return self.parser.state == StreamState.EXPECT_START
def is_assistant_action_turn(self) -> bool:
return self.last_tok in self.encoding.stop_tokens_for_assistant_actions()
def render_for_completion(self) -> list[int]:
# The rendered tokens now include the next turn's starting tokens
# (e.g. `<|start|>assistant`), so we need to process them in the parser.
rendered_tokens = super().render_for_completion()
last_n = -1
to_process = []
while rendered_tokens[last_n] != self.last_tok:
to_process.append(rendered_tokens[last_n])
last_n -= 1
for tok in reversed(to_process):
self.parser.process(tok)
return rendered_tokens
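# --- Illustrative sketch (not part of this commit) ---
# How a serving layer might drive a ConversationContext: render the prompt,
# append the engine output, run a built-in tool call when the last assistant
# message targets one, and repeat. The `generate` callable is a hypothetical
# stand-in for the engine's token-generation step.
async def _example_agent_loop(context: HarmonyContext, generate) -> None:
    while True:
        # Render the conversation (including any tool results) into prompt tokens.
        prompt_token_ids = context.render_for_completion()
        # Hypothetical engine call returning {"output_ids": [...], "meta_info": {...}}.
        engine_output = await generate(prompt_token_ids)
        context.append_output(engine_output)
        # Stop once the assistant no longer targets the browser/python tools.
        if not context.need_builtin_tool_call():
            break
        tool_messages = await context.call_tool()
        context.append_output(tool_messages)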
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import datetime
import json
from collections.abc import Iterable
from typing import Literal, Optional, Union
from openai.types.responses import (
ResponseOutputItem,
ResponseOutputMessage,
ResponseOutputText,
ResponseReasoningItem,
)
from openai.types.responses.response_function_tool_call import ResponseFunctionToolCall
from openai.types.responses.response_function_web_search import (
ActionFind,
ActionOpenPage,
ActionSearch,
ResponseFunctionWebSearch,
)
from openai.types.responses.response_reasoning_item import (
Content as ResponseReasoningTextContent,
)
from openai.types.responses.tool import Tool
from openai_harmony import (
Author,
Conversation,
DeveloperContent,
HarmonyEncodingName,
Message,
ReasoningEffort,
Role,
StreamableParser,
SystemContent,
TextContent,
ToolDescription,
load_harmony_encoding,
)
from sglang.srt.entrypoints.openai.protocol import ResponseInputOutputItem
from sglang.srt.utils import random_uuid
REASONING_EFFORT = {
"high": ReasoningEffort.HIGH,
"medium": ReasoningEffort.MEDIUM,
"low": ReasoningEffort.LOW,
}
_harmony_encoding = None
def get_encoding():
global _harmony_encoding
if _harmony_encoding is None:
_harmony_encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
return _harmony_encoding
def get_system_message(
model_identity: Optional[str] = None,
reasoning_effort: Optional[Literal["high", "medium", "low"]] = None,
start_date: Optional[str] = None,
browser_description: Optional[str] = None,
python_description: Optional[str] = None,
) -> Message:
sys_msg_content = SystemContent.new()
if model_identity is not None:
sys_msg_content = sys_msg_content.with_model_identity(model_identity)
if reasoning_effort is not None:
sys_msg_content = sys_msg_content.with_reasoning_effort(
REASONING_EFFORT[reasoning_effort]
)
if start_date is None:
start_date = datetime.datetime.now().strftime("%Y-%m-%d")
sys_msg_content = sys_msg_content.with_conversation_start_date(start_date)
if browser_description is not None:
sys_msg_content = sys_msg_content.with_tools(browser_description)
if python_description is not None:
sys_msg_content = sys_msg_content.with_tools(python_description)
sys_msg = Message.from_role_and_content(Role.SYSTEM, sys_msg_content)
return sys_msg
def get_developer_message(
instructions: Optional[str] = None, tools: Optional[list[Tool]] = None
) -> Message:
dev_msg_content = DeveloperContent.new()
if instructions is not None:
dev_msg_content = dev_msg_content.with_instructions(instructions)
if tools is not None:
function_tools = []
for tool in tools:
if tool.type in ("web_search_preview", "code_interpreter"):
# These are built-in tools that are added to the system message.
pass
elif tool.type == "function":
function_tools.append(tool)
else:
raise ValueError(f"tool type {tool.type} not supported")
if function_tools:
function_tool_descriptions = [
ToolDescription.new(
name=tool.name,
description=tool.description,
parameters=tool.parameters,
)
for tool in function_tools
]
dev_msg_content = dev_msg_content.with_function_tools(
function_tool_descriptions
)
dev_msg = Message.from_role_and_content(Role.DEVELOPER, dev_msg_content)
return dev_msg
def get_user_message(content: str) -> Message:
return Message.from_role_and_content(Role.USER, content)
def parse_response_input(
response_msg: ResponseInputOutputItem,
prev_responses: list[Union[ResponseOutputItem, ResponseReasoningItem]],
) -> Message:
if not isinstance(response_msg, dict):
response_msg = response_msg.model_dump()
if "type" not in response_msg or response_msg["type"] == "message":
role = response_msg["role"]
content = response_msg["content"]
if role == "system":
# User is trying to set a system message. Change it to:
# <|start|>developer<|message|># Instructions
# {instructions}<|end|>
role = "developer"
text_prefix = "Instructions:\n"
else:
text_prefix = ""
if isinstance(content, str):
msg = Message.from_role_and_content(role, text_prefix + content)
else:
contents = [TextContent(text=text_prefix + c["text"]) for c in content]
msg = Message.from_role_and_contents(role, contents)
elif response_msg["type"] == "function_call_output":
call_id = response_msg["call_id"]
call_response: Optional[ResponseFunctionToolCall] = None
for prev_response in reversed(prev_responses):
if (
isinstance(prev_response, ResponseFunctionToolCall)
and prev_response.call_id == call_id
):
call_response = prev_response
break
if call_response is None:
raise ValueError(f"No call message found for {call_id}")
msg = Message.from_author_and_content(
Author.new(Role.TOOL, f"functions.{call_response.name}"),
response_msg["output"],
)
elif response_msg["type"] == "reasoning":
content = response_msg["content"]
assert len(content) == 1
msg = Message.from_role_and_content(Role.ASSISTANT, content[0]["text"])
elif response_msg["type"] == "function_call":
msg = Message.from_role_and_content(Role.ASSISTANT, response_msg["arguments"])
msg = msg.with_channel("commentary")
msg = msg.with_recipient(f"functions.{response_msg['name']}")
msg = msg.with_content_type("json")
else:
raise ValueError(f"Unknown input type: {response_msg['type']}")
return msg
def parse_response_output(output: ResponseOutputItem) -> Message:
if isinstance(output, ResponseOutputMessage):
role = output.role
contents = [TextContent(text=c.text) for c in output.content]
msg = Message.from_role_and_contents(role, contents)
return msg
elif isinstance(output, ResponseFunctionToolCall):
msg = Message.from_role_and_content(Role.ASSISTANT, output.arguments)
msg = msg.with_channel("commentary")
msg = msg.with_recipient(output.name)
msg = msg.with_content_type("json")
return msg
else:
raise ValueError(f"Unknown output type: {type(output)}")
def parse_chat_input(chat_msg) -> Message:
role = chat_msg.role
content = chat_msg.content
if isinstance(content, str):
contents = [TextContent(text=content)]
else:
# TODO: Support refusal.
contents = [TextContent(text=c.text) for c in content]
msg = Message.from_role_and_contents(role, contents)
return msg
def render_for_completion(messages: list[Message]) -> list[int]:
conversation = Conversation.from_messages(messages)
token_ids = get_encoding().render_conversation_for_completion(
conversation, Role.ASSISTANT
)
return token_ids
def get_stop_tokens_for_assistant_actions() -> list[int]:
return get_encoding().stop_tokens_for_assistant_actions()
def get_streamable_parser_for_assistant() -> StreamableParser:
return StreamableParser(get_encoding(), role=Role.ASSISTANT)
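# --- Illustrative sketch (not part of this commit) ---
# How the helpers above fit together: build system/developer/user messages for
# a Harmony-formatted prompt and render them into token ids for the engine.
# The reasoning effort and instruction text are example values.
def _example_render_prompt() -> list[int]:
    messages = [
        get_system_message(reasoning_effort="low"),
        get_developer_message(instructions="Answer concisely."),
        get_user_message("What is the capital of France?"),
    ]
    # The rendered tokens end with the assistant turn header, ready for generation.
    return render_for_completion(messages)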
def parse_output_message(message: Message):
if message.author.role != "assistant":
# This is a message from a tool to the assistant (e.g., search result).
# Don't include it in the final output for now. This aligns with
# OpenAI's behavior on models like o4-mini.
return []
output_items = []
recipient = message.recipient
if recipient is not None and recipient.startswith("browser."):
if len(message.content) != 1:
raise ValueError("Invalid number of contents in browser message")
content = message.content[0]
browser_call = json.loads(content.text)
# TODO: translate to url properly!
if recipient == "browser.search":
action = ActionSearch(
query=f"cursor:{browser_call.get('query', '')}", type="search"
)
elif recipient == "browser.open":
action = ActionOpenPage(
url=f"cursor:{browser_call.get('url', '')}", type="open_page"
)
elif recipient == "browser.find":
action = ActionFind(
pattern=browser_call["pattern"],
url=f"cursor:{browser_call.get('url', '')}",
type="find",
)
else:
raise ValueError(f"Unknown browser action: {recipient}")
web_search_item = ResponseFunctionWebSearch(
id=f"ws_{random_uuid()}",
action=action,
status="completed",
type="web_search_call",
)
output_items.append(web_search_item)
elif message.channel == "analysis":
for content in message.content:
reasoning_item = ResponseReasoningItem(
id=f"rs_{random_uuid()}",
type="reasoning",
summary=[],
content=[
ResponseReasoningTextContent(
text=content.text, type="reasoning_text"
)
],
status=None,
)
output_items.append(reasoning_item)
elif message.channel == "commentary":
if message.recipient.startswith("functions."):
function_name = message.recipient.split(".")[-1]
for content in message.content:
random_id = random_uuid()
response_item = ResponseFunctionToolCall(
arguments=content.text,
call_id=f"call_{random_id}",
type="function_call",
name=function_name,
id=f"ft_{random_id}",
)
output_items.append(response_item)
elif message.recipient.startswith("python") or message.recipient.startswith(
"browser"
):
for content in message.content:
reasoning_item = ResponseReasoningItem(
id=f"rs_{random_uuid()}",
type="reasoning",
summary=[],
content=[
ResponseReasoningTextContent(
text=content.text, type="reasoning_text"
)
],
status=None,
)
output_items.append(reasoning_item)
else:
raise ValueError(f"Unknown recipient: {message.recipient}")
elif message.channel == "final":
contents = []
for content in message.content:
output_text = ResponseOutputText(
text=content.text,
annotations=[], # TODO
type="output_text",
logprobs=None, # TODO
)
contents.append(output_text)
text_item = ResponseOutputMessage(
id=f"msg_{random_uuid()}",
content=contents,
role=message.author.role,
status="completed",
type="message",
)
output_items.append(text_item)
else:
raise ValueError(f"Unknown channel: {message.channel}")
return output_items
def parse_remaining_state(parser: StreamableParser):
if not parser.current_content:
return []
if parser.current_role != Role.ASSISTANT:
return []
current_recipient = parser.current_recipient
if current_recipient is not None and current_recipient.startswith("browser."):
return []
if parser.current_channel == "analysis":
reasoning_item = ResponseReasoningItem(
id=f"rs_{random_uuid()}",
type="reasoning",
summary=[],
content=[
ResponseReasoningTextContent(
text=parser.current_content, type="reasoning_text"
)
],
status=None,
)
return [reasoning_item]
elif parser.current_channel == "final":
output_text = ResponseOutputText(
text=parser.current_content,
annotations=[],  # TODO
type="output_text",
logprobs=None,  # TODO
)
text_item = ResponseOutputMessage(
id=f"msg_{random_uuid()}",
content=[output_text],
role="assistant",
status="completed",
type="message",
)
return [text_item]
return []
def parse_output_into_messages(token_ids: Iterable[int]):
parser = get_streamable_parser_for_assistant()
for token_id in token_ids:
parser.process(token_id)
return parser
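# --- Illustrative sketch (not part of this commit) ---
# The decode-side flow: feed generated token ids through the streaming parser,
# convert each completed Harmony message into Responses API output items, then
# recover anything left in a partially generated message (e.g. a max-token stop).
def _example_parse_output(generated_token_ids: list[int]) -> list:
    parser = parse_output_into_messages(generated_token_ids)
    output_items = []
    for message in parser.messages:
        output_items.extend(parse_output_message(message))
    output_items.extend(parse_remaining_state(parser))
    return output_items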
@@ -32,6 +32,7 @@ from typing import AsyncIterator, Callable, Dict, Optional
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
from contextlib import asynccontextmanager
from typing import AsyncGenerator
import numpy as np
import orjson
@@ -56,6 +57,7 @@ from sglang.srt.entrypoints.openai.protocol import (
ErrorResponse,
ModelCard,
ModelList,
ResponsesRequest,
ScoringRequest,
V1RerankReqInput,
)
@@ -147,6 +149,37 @@ async def lifespan(fast_api_app: FastAPI):
)
server_args: ServerArgs = fast_api_app.server_args
tool_server = None
if server_args.tool_server == "demo":
from sglang.srt.entrypoints.openai.tool_server import DemoToolServer
tool_server = DemoToolServer()
elif server_args.tool_server:
from sglang.srt.entrypoints.openai.tool_server import MCPToolServer
tool_server = MCPToolServer()
await tool_server.add_tool_server(server_args.tool_server)
try:
from sglang.srt.entrypoints.openai.serving_responses import (
OpenAIServingResponses,
)
fast_api_app.state.openai_serving_responses = OpenAIServingResponses(
_global_state.tokenizer_manager,
_global_state.template_manager,
enable_prompt_tokens_details=True,
enable_force_include_usage=True,
tool_server=tool_server,
)
except Exception as e:
# print stack trace
import traceback
traceback.print_exc()
logger.warning(f"Can not initialize OpenAIServingResponses, error: {e}")
if server_args.warmups is not None:
await execute_warmups(
server_args.disaggregation_mode,
@@ -843,6 +876,42 @@ async def v1_score_request(request: ScoringRequest, raw_request: Request):
)
@app.post("/v1/responses", dependencies=[Depends(validate_json_request)])
async def v1_responses_request(request: dict, raw_request: Request):
"""Endpoint for the responses API with reasoning support."""
request_obj = ResponsesRequest(**request)
result = await raw_request.app.state.openai_serving_responses.create_responses(
request_obj, raw_request
)
# Handle streaming responses
if isinstance(result, AsyncGenerator):
return StreamingResponse(
result,
media_type="text/event-stream",
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
)
return result
@app.get("/v1/responses/{response_id}")
async def v1_retrieve_responses(response_id: str, raw_request: Request):
"""Retrieve a response by ID."""
return await raw_request.app.state.openai_serving_responses.retrieve_responses(
response_id
)
@app.post("/v1/responses/{response_id}/cancel")
async def v1_cancel_responses(response_id: str, raw_request: Request):
"""Cancel a background response."""
return await raw_request.app.state.openai_serving_responses.cancel_responses(
response_id
)
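# --- Illustrative usage (not part of this commit) ---
# Exercising the new endpoints over plain HTTP. The host, port, and model name
# are assumptions; use whatever model this SGLang instance is serving.
import requests

BASE_URL = "http://localhost:30000"  # assumed server address
resp = requests.post(
    f"{BASE_URL}/v1/responses",
    json={
        "model": "openai/gpt-oss-20b",  # assumed model name
        "input": "Summarize the Harmony message format in one sentence.",
        "stream": False,
    },
)
body = resp.json()
print(body["status"], body["id"])
# A stored response can later be retrieved via GET /v1/responses/{response_id}
# or cancelled via POST /v1/responses/{response_id}/cancel.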
@app.api_route(
"/v1/rerank", methods=["POST", "PUT"], dependencies=[Depends(validate_json_request)]
)
...
@@ -14,9 +14,18 @@
"""Pydantic models for OpenAI API protocol"""
import time
import uuid
from dataclasses import dataclass
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, TypeAlias, Union
from openai.types.responses import (
ResponseFunctionToolCall,
ResponseInputItemParam,
ResponseOutputItem,
ResponseReasoningItem,
)
from openai.types.responses.response import ToolChoice
from openai.types.responses.tool import Tool
from pydantic import (
BaseModel,
Field,
@@ -84,6 +93,7 @@ class UsageInfo(BaseModel):
completion_tokens: Optional[int] = 0
# only used to return cached tokens when --enable-cache-report is set
prompt_tokens_details: Optional[Dict[str, int]] = None
reasoning_tokens: Optional[int] = 0
class StreamOptions(BaseModel):
@@ -428,6 +438,13 @@ class ChatCompletionRequest(BaseModel):
default="auto", examples=["none"]
) # noqa
return_hidden_states: bool = False
reasoning_effort: Optional[Literal["low", "medium", "high"]] = Field(
default="medium",
description="Constrains effort on reasoning for reasoning models. "
"'low' is the least effort, 'high' is the most effort. Reducing reasoning effort can "
"result in faster responses and fewer tokens used on reasoning in a response. "
"Currently only supported for OpenAI models.",
)
@model_validator(mode="before")
@classmethod
@@ -619,6 +636,196 @@ OpenAIServingRequest = Union[
]
# Response API protocol definitions
class ResponseReasoningParam(BaseModel):
"""Reasoning parameters for responses."""
effort: Optional[Literal["low", "medium", "high"]] = Field(
default="medium",
description="Constrains effort on reasoning for reasoning models.",
)
class ResponseTool(BaseModel):
"""Tool definition for responses."""
type: Literal["web_search_preview", "code_interpreter"] = Field(
description="Type of tool to enable"
)
ResponseInputOutputItem: TypeAlias = Union[
ResponseInputItemParam,
"ResponseReasoningItem",
ResponseFunctionToolCall,
]
class ResponsesRequest(BaseModel):
"""Request body for v1/responses endpoint."""
# Core OpenAI API fields (ordered by official documentation)
background: Optional[bool] = False
include: Optional[
List[
Literal[
"code_interpreter_call.outputs",
"computer_call_output.output.image_url",
"file_search_call.results",
"message.input_image.image_url",
"message.output_text.logprobs",
"reasoning.encrypted_content",
]
]
] = None
input: Union[str, List[ResponseInputOutputItem]]
instructions: Optional[str] = None
max_output_tokens: Optional[int] = None
max_tool_calls: Optional[int] = None
metadata: Optional[Dict[str, Any]] = None
model: Optional[str] = None # Made optional to match vLLM
parallel_tool_calls: Optional[bool] = True
previous_response_id: Optional[str] = None
reasoning: Optional[ResponseReasoningParam] = None
service_tier: Literal["auto", "default", "flex", "scale", "priority"] = "auto"
store: Optional[bool] = True
stream: Optional[bool] = False
temperature: Optional[float] = None
tool_choice: Literal["auto", "required", "none"] = "auto"
tools: List[ResponseTool] = Field(default_factory=list)
top_logprobs: Optional[int] = 0
top_p: Optional[float] = None
truncation: Optional[Literal["auto", "disabled"]] = "disabled"
user: Optional[str] = None
# Extra SGLang parameters
request_id: str = Field(
default_factory=lambda: f"resp_{uuid.uuid4().hex}",
description="The request_id related to this request. If the caller does not set it, a random uuid will be generated.",
)
priority: int = Field(default=0, description="Request priority")
# SGLang-specific sampling parameters
frequency_penalty: float = 0.0
presence_penalty: float = 0.0
stop: Optional[Union[str, List[str]]] = None
top_k: int = -1
min_p: float = 0.0
repetition_penalty: float = 1.0
# Default sampling parameters
_DEFAULT_SAMPLING_PARAMS = {
"temperature": 0.7,
"top_p": 1.0,
"top_k": -1,
"min_p": 0.0,
"repetition_penalty": 1.0,
}
def to_sampling_params(
self, default_max_tokens: int, default_params: Optional[Dict] = None
) -> Dict[str, Any]:
"""Convert to sampling parameters for generation."""
if default_params is None:
default_params = {}
# Use max_output_tokens if provided, otherwise fall back to the server default
if self.max_output_tokens is not None:
max_tokens = min(self.max_output_tokens, default_max_tokens)
else:
max_tokens = default_max_tokens
# Subtract 1 token to avoid exceeding the context length
max_tokens -= 1
# Get parameters with defaults
temperature = self.temperature
if temperature is None:
temperature = default_params.get(
"temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]
)
top_p = self.top_p
if top_p is None:
top_p = default_params.get("top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"])
params = {
"max_new_tokens": max_tokens,
"temperature": temperature,
"top_p": top_p,
"frequency_penalty": self.frequency_penalty,
"presence_penalty": self.presence_penalty,
"stop": self.stop,
"top_k": self.top_k,
"min_p": self.min_p,
"repetition_penalty": self.repetition_penalty,
}
# Apply any additional default parameters
for key, value in default_params.items():
if key not in params or params[key] is None:
params[key] = value
return params
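# --- Illustrative sketch (not part of this commit) ---
# How the serving layer resolves sampling parameters: explicit request fields
# win, then server-provided defaults, then the class-level fallbacks above.
# The values below are examples only.
def _example_sampling_params() -> Dict[str, Any]:
    request = ResponsesRequest(input="Hello", max_output_tokens=256)
    # default_max_tokens is typically context_len minus the prompt length.
    params = request.to_sampling_params(
        default_max_tokens=1024,
        default_params={"temperature": 0.2},
    )
    # e.g. {"max_new_tokens": 255, "temperature": 0.2, "top_p": 1.0, ...}
    return params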
class PromptTokenUsageInfo(BaseModel):
"""Prompt token usage details."""
cached_tokens: int = 0
class ResponsesResponse(BaseModel):
"""Response body for v1/responses endpoint."""
id: str = Field(default_factory=lambda: f"resp_{time.time()}")
object: Literal["response"] = "response"
created_at: int = Field(default_factory=lambda: int(time.time()))
model: str
output: List[
Union[ResponseOutputItem, ResponseReasoningItem, ResponseFunctionToolCall]
] = Field(default_factory=list)
status: Literal["queued", "in_progress", "completed", "failed", "cancelled"]
usage: Optional[UsageInfo] = None
parallel_tool_calls: bool = True
tool_choice: str = "auto"
tools: List[ResponseTool] = Field(default_factory=list)
@classmethod
def from_request(
cls,
request: ResponsesRequest,
sampling_params: Any,
model_name: str,
created_time: int,
output: List[
Union[ResponseOutputItem, ResponseReasoningItem, ResponseFunctionToolCall]
],
status: str,
usage: Optional[UsageInfo],
) -> "ResponsesResponse":
"""Create a response from a request."""
return cls(
id=request.request_id,
created_at=created_time,
model=model_name,
output=output,
status=status,
usage=usage,
parallel_tool_calls=request.parallel_tool_calls or True,
tool_choice=request.tool_choice,
tools=request.tools,
)
class RequestResponseMetadata(BaseModel):
"""Metadata for request/response tracking."""
request_id: str
final_usage_info: Optional[UsageInfo] = None
@dataclass
class MessageProcessingResult:
"""Result of processing chat messages and applying templates.
@@ -645,3 +852,22 @@ class MessageProcessingResult:
modalities: List[str]
stop: List[str]
tool_call_constraint: Optional[Any] = None
class ResponseReasoningTextContent(BaseModel):
text: str
type: Literal["reasoning_text"] = "reasoning_text"
class ResponseReasoningItem(BaseModel):
id: str
content: list[ResponseReasoningTextContent] = Field(default_factory=list)
summary: list = Field(default_factory=list)
type: Literal["reasoning"] = "reasoning"
encrypted_content: Optional[str] = None
status: Optional[Literal["in_progress", "completed", "incomplete"]]
ResponseInputOutputItem: TypeAlias = Union[
ResponseInputItemParam, "ResponseReasoningItem", ResponseFunctionToolCall
]
@@ -7,8 +7,18 @@ from typing import Any, AsyncGenerator, Dict, List, Optional, Union
from fastapi import Request
from fastapi.responses import ORJSONResponse, StreamingResponse
from openai_harmony import Message as OpenAIMessage
from sglang.srt.conversation import generate_chat_conv
from sglang.srt.entrypoints.harmony_utils import (
get_developer_message,
get_stop_tokens_for_assistant_actions,
get_streamable_parser_for_assistant,
get_system_message,
parse_chat_input,
parse_output_into_messages,
render_for_completion,
)
from sglang.srt.entrypoints.openai.protocol import (
ChatCompletionRequest,
ChatCompletionResponse,
@@ -51,6 +61,26 @@ class OpenAIServingChat(OpenAIServingBase):
):
super().__init__(tokenizer_manager)
self.template_manager = template_manager
self.use_harmony = (
self.tokenizer_manager.model_config.hf_config.model_type == "gpt_oss"
)
if self.use_harmony:
from sglang.srt.function_call.harmony_tool_parser import (
HarmonyToolCallParser,
)
self.harmony_tool_parser = HarmonyToolCallParser()
# NOTE: While OpenAI's chat completion API supports browsing for some
# models, SGLang currently doesn't support it here. Please use the
# Responses API instead.
self.supports_browsing = False
self.browser_tool = None
# NOTE: Chat completion API does not support code interpreter.
# Please use the Responses API instead.
self.supports_code_interpreter = False
self.python_tool = None
def _request_id_prefix(self) -> str:
return "chatcmpl-"
@@ -77,41 +107,66 @@ class OpenAIServingChat(OpenAIServingBase):
is_multimodal = self.tokenizer_manager.model_config.is_multimodal
# Process messages and apply chat template
if not self.use_harmony:
processed_messages = self._process_messages(request, is_multimodal)
# Build sampling parameters
sampling_params = self._build_sampling_params(
request,
processed_messages.stop,
processed_messages.tool_call_constraint,
)
# Handle single vs multiple requests
if is_multimodal:
prompt_kwargs = {"text": processed_messages.prompt}
else:
if isinstance(processed_messages.prompt_ids, str):
prompt_kwargs = {"text": processed_messages.prompt_ids}
else:
prompt_kwargs = {"input_ids": processed_messages.prompt_ids}
adapted_request = GenerateReqInput(
**prompt_kwargs,
image_data=processed_messages.image_data,
video_data=processed_messages.video_data,
audio_data=processed_messages.audio_data,
sampling_params=sampling_params,
return_logprob=request.logprobs,
logprob_start_len=-1,
top_logprobs_num=request.top_logprobs or 0,
stream=request.stream,
return_text_in_logprobs=True,
modalities=processed_messages.modalities,
lora_path=request.lora_path,
bootstrap_host=request.bootstrap_host,
bootstrap_port=request.bootstrap_port,
bootstrap_room=request.bootstrap_room,
return_hidden_states=request.return_hidden_states,
rid=request.rid,
)
else:
processed_messages, prompt_ids = self._make_request_with_harmony(request)
adapted_request = GenerateReqInput(
input_ids=prompt_ids,
sampling_params=self._build_sampling_params(
request,
request.stop,
tool_call_constraint=None,
),
stream=request.stream,
return_logprob=request.logprobs,
logprob_start_len=-1,
top_logprobs_num=request.top_logprobs or 0,
return_text_in_logprobs=True,
lora_path=request.lora_path,
bootstrap_host=request.bootstrap_host,
bootstrap_port=request.bootstrap_port,
bootstrap_room=request.bootstrap_room,
return_hidden_states=request.return_hidden_states,
rid=request.rid,
)
return adapted_request, request
@@ -402,6 +457,12 @@ class OpenAIServingChat(OpenAIServingBase):
cached_tokens = {}
hidden_states = {}
# Harmony tracking
if self.use_harmony:
harmony_parsers = [
get_streamable_parser_for_assistant() for _ in range(request.n)
]
try:
async for content in self.tokenizer_manager.generate_request(
adapted_request, raw_request
@@ -449,14 +510,57 @@
yield f"data: {chunk.model_dump_json()}\n\n"
# Process content delta
if self.use_harmony:
harmony_parser = harmony_parsers[index]
new_token_ids = content["output_ids"]
for token_id in new_token_ids:
harmony_parser.process(token_id)
is_final = harmony_parser.current_channel == "final"
is_analysis = harmony_parser.current_channel == "analysis"
delta = harmony_parser.last_content_delta or ""
if is_analysis:
choice_data = ChatCompletionResponseStreamChoice(
index=index,
delta=DeltaMessage(reasoning_content=delta),
finish_reason=None,
)
chunk = ChatCompletionStreamResponse(
id=content["meta_info"]["id"],
created=int(time.time()),
choices=[choice_data],
model=request.model,
)
yield f"data: {chunk.model_dump_json()}\n\n"
continue
choice_data = ChatCompletionResponseStreamChoice(
index=index,
delta=DeltaMessage(content=delta if delta else None),
finish_reason=None,
matched_stop=None,
logprobs=choice_logprobs,
)
chunk = ChatCompletionStreamResponse(
id=content["meta_info"]["id"],
created=int(time.time()),
choices=[choice_data],
model=request.model,
)
yield f"data: {chunk.model_dump_json()}\n\n"
continue
else:
stream_buffer = stream_buffers.get(index, "")
delta = content["text"][len(stream_buffer) :]
stream_buffers[index] = stream_buffer + delta
# Handle reasoning content
if (
self.tokenizer_manager.server_args.reasoning_parser
and request.separate_reasoning
and not self.use_harmony
):
reasoning_text, delta = self._process_reasoning_stream(
index, delta, reasoning_parser_dict, content, request
@@ -475,8 +579,27 @@
)
yield f"data: {chunk.model_dump_json()}\n\n"
if self.use_harmony and not is_final:
choice_data = ChatCompletionResponseStreamChoice(
index=index,
delta=DeltaMessage(reasoning_content=delta),
finish_reason=None,
)
chunk = ChatCompletionStreamResponse(
id=content["meta_info"]["id"],
created=int(time.time()),
choices=[choice_data],
model=request.model,
)
yield f"data: {chunk.model_dump_json()}\n\n"
# Handle tool calls
# TODO: support tool call parsing for harmony
if (
request.tool_choice != "none"
and request.tools
and not self.use_harmony
):
async for chunk in self._process_tool_call_stream(
index,
delta,
@@ -502,7 +625,7 @@
if delta:
choice_data = ChatCompletionResponseStreamChoice(
index=index,
-delta=DeltaMessage(content=delta if delta else None),
+delta=DeltaMessage(content=delta),
finish_reason=None,
matched_stop=None,
logprobs=choice_logprobs,
@@ -640,6 +763,76 @@
finish_reason = ret_item["meta_info"]["finish_reason"]
text = ret_item["text"]
output_ids = ret_item["output_ids"]
if self.use_harmony:
parser = parse_output_into_messages(output_ids)
output_msgs = parser.messages
if len(output_msgs) == 0:
# The generation has stopped during reasoning.
is_tool_call = False
reasoning_content = parser.current_content
final_content = None
elif len(output_msgs) == 1:
# The generation has stopped during final message.
is_tool_call = False
reasoning_content = output_msgs[0].content[0].text
final_content = parser.current_content
else:
if len(output_msgs) != 2:
raise ValueError(
"Expected 2 output messages (reasoning and final), "
f"but got {len(output_msgs)}."
)
reasoning_msg, final_msg = output_msgs
reasoning_content = reasoning_msg.content[0].text
final_content = final_msg.content[0].text
is_tool_call = final_msg.recipient is not None
if is_tool_call:
# Extract tool call information from final message
tool_call = (
self.harmony_tool_parser.extract_tool_calls_from_message(
final_msg
)
)
tool_calls = [tool_call] if tool_call else []
message = ChatMessage(
role="assistant",
reasoning_content=reasoning_content,
content=None, # Tool calls don't have regular content
tool_calls=tool_calls,
)
else:
# Normal message
message = ChatMessage(
role="assistant",
reasoning_content=reasoning_content,
content=final_content,
)
if is_tool_call:
finish_reason_type = "tool_calls"
elif finish_reason:
finish_reason_type = (
finish_reason["type"] if finish_reason else "stop"
)
else:
finish_reason_type = "stop"
choice_data = ChatCompletionResponseChoice(
index=idx,
message=message,
logprobs=choice_logprobs,
finish_reason=finish_reason_type,
matched_stop=(
finish_reason["matched"]
if finish_reason and "matched" in finish_reason
else None
),
)
choices.append(choice_data)
continue
# Handle reasoning content
reasoning_text = None
@@ -978,3 +1171,33 @@
return f"data: {chunk.model_dump_json()}\n\n"
return None
def _make_request_with_harmony(
self,
request: ChatCompletionRequest,
):
messages: list[OpenAIMessage] = []
# Add system message.
# In Chat Completion API, browsing is enabled by default if the model
# supports it.
assert not self.supports_browsing
assert not self.supports_code_interpreter
sys_msg = get_system_message(
reasoning_effort=request.reasoning_effort,
browser_description=None,
python_description=None,
)
messages.append(sys_msg)
# Add developer message.
dev_msg = get_developer_message()
messages.append(dev_msg)
# Add user message.
for chat_msg in request.messages:
messages.append(parse_chat_input(chat_msg))
# Render prompt token ids.
prompt_token_ids = render_for_completion(messages)
return messages, prompt_token_ids
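# --- Illustrative usage (not part of this commit) ---
# With a gpt-oss model, /v1/chat/completions now builds a Harmony prompt
# (system + developer + user messages) and honors `reasoning_effort`; the
# reasoning stream is returned as `reasoning_content`. Host, port, and model
# name are assumptions.
import requests

resp = requests.post(
    "http://localhost:30000/v1/chat/completions",  # assumed server address
    json={
        "model": "openai/gpt-oss-20b",  # assumed model name
        "messages": [
            {"role": "user", "content": "Give one sentence on the Harmony format."}
        ],
        "reasoning_effort": "low",
        "stream": False,
    },
)
message = resp.json()["choices"][0]["message"]
print(message.get("reasoning_content"))
print(message.get("content"))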
# SPDX-License-Identifier: Apache-2.0
# Adapted from vLLM's OpenAIServingResponses
"""Handler for /v1/responses requests"""
import asyncio
import copy
import json
import logging
import time
from contextlib import AsyncExitStack
from http import HTTPStatus
from typing import Any, AsyncGenerator, AsyncIterator, Optional, Union
import jinja2
import openai.types.responses as openai_responses_types
from fastapi import Request
from fastapi.responses import ORJSONResponse
from openai.types.responses import (
ResponseOutputMessage,
ResponseOutputText,
ResponseReasoningItem,
)
from openai.types.responses.response_function_tool_call import ResponseFunctionToolCall
from openai.types.responses.response_reasoning_item import (
Content as ResponseReasoningTextContent,
)
from openai_harmony import Message as OpenAIMessage
from sglang.srt.entrypoints.context import (
ConversationContext,
HarmonyContext,
SimpleContext,
StreamingHarmonyContext,
)
from sglang.srt.entrypoints.harmony_utils import (
get_developer_message,
get_stop_tokens_for_assistant_actions,
get_system_message,
get_user_message,
parse_output_message,
parse_remaining_state,
parse_response_input,
render_for_completion,
)
from sglang.srt.entrypoints.openai.protocol import (
ChatCompletionMessageParam,
ChatCompletionRequest,
PromptTokenUsageInfo,
RequestResponseMetadata,
ResponsesRequest,
ResponsesResponse,
UsageInfo,
)
from sglang.srt.entrypoints.openai.serving_chat import OpenAIServingChat
from sglang.srt.entrypoints.openai.tool_server import MCPToolServer, ToolServer
from sglang.srt.managers.io_struct import GenerateReqInput
from sglang.srt.managers.template_manager import TemplateManager
from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.reasoning_parser import ReasoningParser
from sglang.srt.utils import random_uuid
logger = logging.getLogger(__name__)
class OpenAIServingResponses(OpenAIServingChat):
"""Handler for /v1/responses requests"""
def __init__(
self,
tokenizer_manager: TokenizerManager,
template_manager: TemplateManager,
*,
enable_prompt_tokens_details: bool = False,
enable_force_include_usage: bool = False,
tool_server: Optional[ToolServer] = None,
) -> None:
super().__init__(tokenizer_manager, template_manager)
# template_manager is already set by parent class
self.reasoning_parser = self.tokenizer_manager.server_args.reasoning_parser
self.enable_prompt_tokens_details = enable_prompt_tokens_details
self.enable_force_include_usage = enable_force_include_usage
# Get default sampling params from model config if available
self.default_sampling_params = {}
self.supports_browsing = (
tool_server.has_tool("browser") if tool_server else False
)
self.supports_code_interpreter = (
tool_server.has_tool("python") if tool_server else False
)
self.tool_server = tool_server
# Get from model config
self.use_harmony = (
self.tokenizer_manager.model_config.hf_config.model_type == "gpt_oss"
)
if self.use_harmony:
# OpenAI models have two EOS-like tokens: <|return|> and <|call|>.
# We need to add them to the stop token ids.
if "stop_token_ids" not in self.default_sampling_params:
self.default_sampling_params["stop_token_ids"] = []
self.default_sampling_params["stop_token_ids"].extend(
get_stop_tokens_for_assistant_actions()
)
# Response storage for background and retrieval operations
# Note: In production, this should use a proper storage backend (Redis, database)
# with TTL/expiration to prevent memory leaks
self.response_store: dict[str, ResponsesResponse] = {}
self.response_store_lock = asyncio.Lock()
# Message storage for conversation continuity
# Note: In production, this should use a proper storage backend (Redis, database)
# with TTL/expiration to prevent memory leaks
self.msg_store: dict[
str, Union[list[ChatCompletionMessageParam], list["OpenAIMessage"]]
] = {}
self.background_tasks: dict[str, asyncio.Task] = {}
def _request_id_prefix(self) -> str:
return "resp_"
async def create_responses(
self,
request: ResponsesRequest,
raw_request: Optional[Request] = None,
) -> Union[AsyncGenerator[str, None], ResponsesResponse, ORJSONResponse]:
# Validate model
if not self.tokenizer_manager:
return self.create_error_response("Model not loaded")
# FIXME: If the engine is dead, raise an error
# This is required for the streaming case
# Handle the previous response ID
prev_response_id = request.previous_response_id
if prev_response_id is not None:
if not prev_response_id.startswith("resp_"):
return self._make_invalid_id_error(prev_response_id)
async with self.response_store_lock:
prev_response = self.response_store.get(prev_response_id)
if prev_response is None:
return self._make_not_found_error(prev_response_id)
else:
prev_response = None
try:
model_name = request.model
tokenizer = self.tokenizer_manager.tokenizer
if self.use_harmony:
messages, request_prompts, engine_prompts = (
self._make_request_with_harmony(request, prev_response)
)
else:
messages, request_prompts, engine_prompts = await self._make_request(
request, prev_response, tokenizer
)
except (ValueError, TypeError, RuntimeError, jinja2.TemplateError) as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(f"{e} {e.__cause__}")
request_metadata = RequestResponseMetadata(request_id=request.request_id)
if raw_request:
raw_request.state.request_metadata = request_metadata
if (
self.tool_server is not None
and isinstance(self.tool_server, MCPToolServer)
and (request.background or request.stream)
and request.tools
and any(
tool.type in ["web_search_preview", "code_interpreter"]
for tool in request.tools
)
):
return self.create_error_response(
"MCP tool server is not supported in background mode and "
"streaming mode"
)
# Schedule the request and get the result generator
generators: list[AsyncGenerator[Any, None]] = []
tool_list = []
if self.use_harmony:
if self.supports_browsing:
tool_list.append("browser")
if self.supports_code_interpreter:
tool_list.append("python")
async with AsyncExitStack() as exit_stack:
try:
if self.tool_server is not None:
tool_session_ctxs: dict[str, Any] = {
tool_name: exit_stack.enter_async_context(
self.tool_server.get_tool_session(tool_name)
)
for tool_name in tool_list
}
tool_sessions = {}
for tool_name in tool_list:
tool_sessions[tool_name] = await tool_session_ctxs[tool_name]
else:
assert len(tool_list) == 0
tool_sessions = {}
for i, engine_prompt in enumerate(engine_prompts):
# Calculate default max tokens from context length minus prompt length
if hasattr(engine_prompt, "__len__"):
prompt_length = len(engine_prompt)
elif isinstance(engine_prompt, list):
prompt_length = len(engine_prompt)
else:
prompt_length = 0
context_len = (
self.tokenizer_manager.model_config.context_len
if hasattr(self.tokenizer_manager.model_config, "context_len")
else 4096
)
default_max_tokens = max(
context_len - prompt_length, 512
) # Ensure minimum 512 tokens
sampling_params = request.to_sampling_params(
default_max_tokens, self.default_sampling_params
)
context: ConversationContext
if self.use_harmony:
if request.stream:
context = StreamingHarmonyContext(messages, tool_sessions)
else:
context = HarmonyContext(messages, tool_sessions)
else:
context = SimpleContext()
# Create GenerateReqInput for SGLang
adapted_request = GenerateReqInput(
input_ids=engine_prompt,
sampling_params=sampling_params,
stream=request.stream,
rid=request.request_id,
background=request.background,
)
generator = self._generate_with_builtin_tools(
request.request_id,
request_prompts[i],
adapted_request,
sampling_params,
context,
raw_request=raw_request,
priority=request.priority,
)
generators.append(generator)
except ValueError as e:
return self.create_error_response(str(e))
assert len(generators) == 1
(result_generator,) = generators
# Store the input messages
if request.store:
self.msg_store[request.request_id] = messages
if request.background:
created_time = int(time.time())
response = ResponsesResponse.from_request(
request,
sampling_params,
model_name=model_name,
created_time=created_time,
output=[],
status="queued",
usage=None,
)
async with self.response_store_lock:
self.response_store[response.id] = response
# Run the request in the background
task = asyncio.create_task(
self._run_background_request(
request,
sampling_params,
result_generator,
context,
model_name,
tokenizer,
request_metadata,
created_time,
),
name=f"create_{response.id}",
)
# For cleanup
self.background_tasks[response.id] = task
task.add_done_callback(
lambda _: self.background_tasks.pop(response.id, None)
)
return response
if request.stream:
return self.responses_stream_generator(
request,
sampling_params,
result_generator,
context,
model_name,
tokenizer,
request_metadata,
)
try:
result: Union[ORJSONResponse, ResponsesResponse] = (
await self.responses_full_generator(
request,
sampling_params,
result_generator,
context,
model_name,
tokenizer,
request_metadata,
)
)
return result
except Exception as e:
return self.create_error_response(str(e))
return self.create_error_response("Unknown error")
async def _make_request(
self,
request: ResponsesRequest,
prev_response: Optional[ResponsesResponse],
tokenizer: Any,
):
# Construct the input messages
messages = self._construct_input_messages(request, prev_response)
# Follow SGLang's pattern: create a ChatCompletionRequest and process messages
try:
# Convert ResponsesRequest to ChatCompletionRequest for processing
chat_request = ChatCompletionRequest(
model=request.model,
messages=messages,
stream=request.stream,
)
# Follow SGLang's _process_messages pattern
is_multimodal = self.tokenizer_manager.model_config.is_multimodal
processed_messages = self._process_messages(chat_request, is_multimodal)
# Extract the results
if is_multimodal:
request_prompts = [processed_messages.prompt]
engine_prompts = [processed_messages.prompt]
else:
request_prompts = [processed_messages.prompt_ids]
engine_prompts = [processed_messages.prompt_ids]
except Exception as e:
logger.warning(f"Chat processing failed, using fallback: {e}")
# Fallback to simple encoding
prompt_text = ""
for msg in messages:
role = msg.get("role", "user")
content = msg.get("content", "")
prompt_text += f"{role}: {content}\n"
prompt_ids = tokenizer.encode(prompt_text)
request_prompts = [prompt_ids]
engine_prompts = [prompt_ids]
return messages, request_prompts, engine_prompts
def _make_request_with_harmony(
self,
request: ResponsesRequest,
prev_response: Optional[ResponsesResponse],
):
if request.tool_choice != "auto":
raise NotImplementedError(
"Only 'auto' tool_choice is supported in " "response API"
)
messages = self._construct_input_messages_with_harmony(request, prev_response)
prompt_token_ids = render_for_completion(messages)
engine_prompt = prompt_token_ids
return messages, [prompt_token_ids], [engine_prompt]
async def responses_full_generator(
self,
request: ResponsesRequest,
sampling_params: Any,
result_generator: AsyncIterator[Any],
context: ConversationContext,
model_name: str,
tokenizer: Any,
request_metadata: RequestResponseMetadata,
created_time: Optional[int] = None,
) -> Union[ResponsesResponse, ORJSONResponse]:
if created_time is None:
created_time = int(time.time())
try:
async for _ in result_generator:
pass
except asyncio.CancelledError:
return self.create_error_response("Client disconnected")
except ValueError as e:
return self.create_error_response(str(e))
if self.use_harmony:
assert isinstance(context, HarmonyContext)
output = self._make_response_output_items_with_harmony(context)
# TODO: these are all 0 for now!
num_prompt_tokens = context.num_prompt_tokens
num_generated_tokens = context.num_output_tokens
num_cached_tokens = context.num_cached_tokens
num_reasoning_tokens = context.num_reasoning_tokens
else:
assert isinstance(context, SimpleContext)
final_res = context.last_output
assert final_res is not None
output = self._make_response_output_items(
request, final_res["text"], tokenizer
)
# Calculate usage from actual output
if hasattr(final_res, "meta_info"):
num_prompt_tokens = final_res.meta_info.get("prompt_tokens", 0)
num_generated_tokens = final_res.meta_info.get("completion_tokens", 0)
num_cached_tokens = final_res.meta_info.get("cached_tokens", 0)
elif hasattr(final_res, "prompt_token_ids") and hasattr(
final_res, "outputs"
):
# Fallback calculation if meta_info not available
num_prompt_tokens = (
len(final_res.prompt_token_ids) if final_res.prompt_token_ids else 0
)
num_generated_tokens = (
len(final_res.outputs[0].token_ids)
if final_res.outputs and final_res.outputs[0].token_ids
else 0
)
num_cached_tokens = getattr(final_res, "num_cached_tokens", 0)
num_reasoning_tokens = 0
else:
# Final fallback
num_prompt_tokens = 0
num_generated_tokens = 0
num_cached_tokens = 0
num_reasoning_tokens = 0
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=num_generated_tokens,
total_tokens=num_prompt_tokens + num_generated_tokens,
reasoning_tokens=num_reasoning_tokens,
)
if self.enable_prompt_tokens_details and num_cached_tokens:
usage.prompt_tokens_details = PromptTokenUsageInfo(
cached_tokens=num_cached_tokens
)
request_metadata.final_usage_info = usage
response = ResponsesResponse.from_request(
request,
sampling_params,
model_name=model_name,
created_time=created_time,
output=output,
status="completed",
usage=usage,
)
if request.store:
async with self.response_store_lock:
stored_response = self.response_store.get(response.id)
# If the response is already cancelled, don't update it
if stored_response is None or stored_response.status != "cancelled":
self.response_store[response.id] = response
return response
def _make_response_output_items(
self,
request: ResponsesRequest,
final_output: Any,
tokenizer: Any,
):
# Handle reasoning parsing if enabled
if self.reasoning_parser:
# Use standard reasoning parser (openai maps to T4Detector internally)
reasoning_parser = ReasoningParser(
model_type=self.reasoning_parser, stream_reasoning=False
)
reasoning_content, content = reasoning_parser.parse_non_stream(final_output)
else:
reasoning_content = None
content = final_output
output_items = []
if reasoning_content:
reasoning_item = ResponseReasoningItem(
id=f"rs_{random_uuid()}",
type="reasoning",
summary=[],
content=[
ResponseReasoningTextContent(
type="reasoning_text", text=reasoning_content
),
],
status=None,
)
output_items.append(reasoning_item)
if content:
output_text = ResponseOutputText(
text=content,
annotations=[], # TODO
type="output_text",
logprobs=None, # TODO
)
message = ResponseOutputMessage(
id=f"msg_{random_uuid()}",
content=[output_text],
role="assistant",
status="completed",
type="message",
)
output_items.append(message)
return output_items
def _make_response_output_items_with_harmony(
self,
context: HarmonyContext,
):
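        """Convert the harmony messages produced in this turn into Responses API output items."""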
output_items = []
num_init_messages = context.num_init_messages
for msg in context.messages[num_init_messages:]:
output_items.extend(parse_output_message(msg))
# Handle the generation stopped in the middle (if any).
last_items = parse_remaining_state(context.parser)
if last_items:
output_items.extend(last_items)
return output_items
def _construct_input_messages(
self,
request: ResponsesRequest,
prev_response: Optional[ResponsesResponse] = None,
) -> list[ChatCompletionMessageParam]:
messages: list[ChatCompletionMessageParam] = []
if request.instructions:
messages.append(
{
"role": "system",
"content": request.instructions,
}
)
# Prepend the conversation history
if prev_response is not None:
# Add the previous messages
prev_msg = self.msg_store[prev_response.id]
messages.extend(prev_msg)
# Add the previous output
for output_item in prev_response.output:
# NOTE: We skip the reasoning output of the previous response
if isinstance(output_item, ResponseReasoningItem):
continue
for content in output_item.content:
                    messages.append(
                        {
                            "role": "assistant",
                            "content": content.text,
                        }
                    )
# Append the new input
# Responses API supports simple text inputs without chat format
if isinstance(request.input, str):
messages.append({"role": "user", "content": request.input})
else:
messages.extend(request.input) # type: ignore
return messages
def _construct_input_messages_with_harmony(
self,
request: ResponsesRequest,
prev_response: Optional[ResponsesResponse],
) -> list["OpenAIMessage"]:
messages: list["OpenAIMessage"] = []
if prev_response is None:
# New conversation.
reasoning_effort = request.reasoning.effort if request.reasoning else None
tool_types = [tool.type for tool in request.tools]
enable_browser = (
"web_search_preview" in tool_types and self.tool_server is not None
)
enable_code_interpreter = (
"code_interpreter" in tool_types and self.tool_server is not None
)
sys_msg = get_system_message(
reasoning_effort=reasoning_effort,
browser_description=(
self.tool_server.get_tool_description("browser")
if self.tool_server and enable_browser
else None
),
python_description=(
self.tool_server.get_tool_description("python")
if self.tool_server and enable_code_interpreter
else None
),
)
messages.append(sys_msg)
dev_msg = get_developer_message(request.instructions, request.tools)
messages.append(dev_msg)
else:
# Continue the previous conversation.
# FIXME: Currently, request params like reasoning and
# instructions are ignored.
prev_msgs = self.msg_store[prev_response.id]
            # Remove the previous chain-of-thought if there is a new "final"
            # message.
if (
len(prev_msgs) > 0
and hasattr(prev_msgs[-1], "channel")
and prev_msgs[-1].channel == "final"
): # type: ignore[union-attr]
prev_final_msg_idx = -1
for i in range(len(prev_msgs) - 2, -1, -1):
if (
hasattr(prev_msgs[i], "channel")
and prev_msgs[i].channel == "final"
): # type: ignore[union-attr]
prev_final_msg_idx = i
break
recent_turn_msgs = prev_msgs[prev_final_msg_idx + 1 :]
del prev_msgs[prev_final_msg_idx + 1 :]
for msg in recent_turn_msgs:
if (
hasattr(msg, "channel") and msg.channel != "analysis"
): # type: ignore[union-attr]
prev_msgs.append(msg)
messages.extend(prev_msgs)
# Append the new input.
# Responses API supports simple text inputs without chat format.
if isinstance(request.input, str):
messages.append(get_user_message(request.input))
else:
if prev_response is not None:
prev_outputs = copy(prev_response.output)
else:
prev_outputs = []
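            # Track prior function calls so tool outputs in the new input can be
            # matched back to the calls that produced them.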
for response_msg in request.input:
messages.append(parse_response_input(response_msg, prev_outputs))
if isinstance(response_msg, ResponseFunctionToolCall):
prev_outputs.append(response_msg)
return messages
async def _run_background_request(
self,
request: ResponsesRequest,
sampling_params: Any,
result_generator: AsyncIterator[Any],
context: ConversationContext,
model_name: str,
tokenizer: Any,
request_metadata: RequestResponseMetadata,
created_time: Optional[int] = None,
*args,
**kwargs,
):
try:
# Update the status to "in_progress"
async with self.response_store_lock:
stored_response = self.response_store.get(request.request_id)
assert stored_response is not None
stored_response.status = "in_progress"
response = await self.responses_full_generator(
request,
sampling_params,
result_generator,
context,
model_name,
tokenizer,
request_metadata,
created_time,
*args,
**kwargs,
)
except Exception as e:
logger.exception("Background request failed for %s", request.request_id)
response = self.create_error_response(str(e))
if isinstance(response, ORJSONResponse):
# If the request has failed, update the status to "failed"
response_id = request.request_id
async with self.response_store_lock:
stored_response = self.response_store.get(response_id)
assert stored_response is not None
if stored_response.status not in ("completed", "cancelled"):
stored_response.status = "failed"
async def retrieve_responses(
self,
response_id: str,
) -> Union[ResponsesResponse, ORJSONResponse]:
if not response_id.startswith("resp_"):
return self._make_invalid_id_error(response_id)
async with self.response_store_lock:
response = self.response_store.get(response_id)
if response is None:
return self._make_not_found_error(response_id)
return response
async def cancel_responses(
self,
response_id: str,
) -> Union[ResponsesResponse, ORJSONResponse]:
if not response_id.startswith("resp_"):
return self._make_invalid_id_error(response_id)
async with self.response_store_lock:
response = self.response_store.get(response_id)
if response is None:
return self._make_not_found_error(response_id)
prev_status = response.status
if prev_status not in ("queued", "in_progress"):
return self.create_error_response(
err_type="invalid_request_error",
message="Cannot cancel a synchronous response.",
)
# Update the status to "cancelled"
response.status = "cancelled"
# Abort the request
if task := self.background_tasks.get(response_id):
task.cancel()
try:
await task
except asyncio.CancelledError:
logger.exception("Background task for %s was cancelled", response_id)
return response
def _make_invalid_id_error(self, response_id: str):
return self.create_error_response(
message=(
f"Invalid 'response_id': '{response_id}'. "
"Expected an ID that begins with 'resp'."
),
err_type="invalid_request_error",
param="response_id",
)
def _make_not_found_error(self, response_id: str):
return self.create_error_response(
message=f"Response with id '{response_id}' not found.",
err_type="invalid_request_error",
status_code=HTTPStatus.NOT_FOUND,
param="response_id",
)
async def responses_stream_generator(
self,
request: ResponsesRequest,
sampling_params: Any,
result_generator: AsyncIterator[StreamingHarmonyContext],
context: StreamingHarmonyContext,
model_name: str,
tokenizer: Any,
request_metadata: RequestResponseMetadata,
created_time: Optional[int] = None,
) -> AsyncGenerator[str, None]:
# TODO:
# 1. Handle disconnect
created_time = created_time or int(time.time())
sequence_number = 0
def _send_event(event):
nonlocal sequence_number
# Set sequence_number if the event has this attribute
if hasattr(event, "sequence_number"):
event.sequence_number = sequence_number
sequence_number += 1
# Get event type from the event's type field if it exists
event_type = getattr(event, "type", "unknown")
return (
f"event: {event_type}\n"
f"data: {event.model_dump_json(indent=None)}\n\n"
)
current_content_index = 0
current_output_index = 0
current_item_id = f"item_{random_uuid()}"
sent_output_item_added = False
initial_response = ResponsesResponse.from_request(
request,
sampling_params,
model_name=model_name,
created_time=created_time,
output=[],
status="in_progress",
usage=None,
).model_dump()
yield _send_event(
openai_responses_types.ResponseCreatedEvent(
type="response.created",
sequence_number=-1,
response=initial_response,
)
)
yield _send_event(
openai_responses_types.ResponseInProgressEvent(
type="response.in_progress",
sequence_number=-1,
response=initial_response,
)
)
async for ctx in result_generator:
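            # The parser expects a new message to start: finalize the previous
            # output item by emitting its "done" events before moving on.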
if ctx.is_expecting_start():
current_output_index += 1
sent_output_item_added = False
if len(ctx.parser.messages) > 0:
previous_item = ctx.parser.messages[-1]
if previous_item.recipient is not None:
# Deal with tool call here
pass
elif previous_item.channel == "analysis":
reasoning_item = ResponseReasoningItem(
id=f"rs_{random_uuid()}",
type="reasoning",
summary=[],
content=[
ResponseReasoningTextContent(
text=previous_item.content[0].text,
type="reasoning_text",
),
],
status="completed",
)
yield _send_event(
openai_responses_types.ResponseReasoningTextDoneEvent(
type="response.reasoning_text.done",
item_id=current_item_id,
sequence_number=-1,
output_index=current_output_index,
content_index=current_content_index,
text=previous_item.content[0].text,
)
)
yield _send_event(
openai_responses_types.ResponseOutputItemDoneEvent(
type="response.output_item.done",
sequence_number=-1,
output_index=current_output_index,
item=reasoning_item,
)
)
elif previous_item.channel == "final":
text_content = openai_responses_types.ResponseOutputText(
type="output_text",
text=previous_item.content[0].text,
annotations=[],
)
yield _send_event(
openai_responses_types.ResponseTextDoneEvent(
type="response.output_text.done",
sequence_number=-1,
output_index=current_output_index,
content_index=current_content_index,
text=previous_item.content[0].text,
logprobs=[],
item_id=current_item_id,
)
)
yield _send_event(
openai_responses_types.ResponseContentPartDoneEvent(
type="response.content_part.done",
sequence_number=-1,
item_id=current_item_id,
output_index=current_output_index,
content_index=current_content_index,
part=text_content,
)
)
yield _send_event(
openai_responses_types.ResponseOutputItemDoneEvent(
type="response.output_item.done",
sequence_number=-1,
output_index=current_output_index,
item=openai_responses_types.ResponseOutputMessage(
id=current_item_id,
type="message",
role="assistant",
content=[text_content],
status="completed",
),
)
)
if ctx.parser.last_content_delta:
if (
ctx.parser.current_channel == "final"
and ctx.parser.current_recipient is None
):
if not sent_output_item_added:
sent_output_item_added = True
yield _send_event(
openai_responses_types.ResponseOutputItemAddedEvent(
type="response.output_item.added",
sequence_number=-1,
output_index=current_output_index,
item=openai_responses_types.ResponseOutputMessage(
id=current_item_id,
type="message",
role="assistant",
content=[],
status="in_progress",
),
)
)
yield _send_event(
openai_responses_types.ResponseContentPartAddedEvent(
type="response.content_part.added",
sequence_number=-1,
output_index=current_output_index,
item_id=current_item_id,
content_index=current_content_index,
part=openai_responses_types.ResponseOutputText(
type="output_text",
text="",
annotations=[],
logprobs=[],
),
)
)
yield _send_event(
openai_responses_types.ResponseTextDeltaEvent(
type="response.output_text.delta",
sequence_number=-1,
content_index=current_content_index,
output_index=current_output_index,
item_id=current_item_id,
delta=ctx.parser.last_content_delta,
# TODO, use logprobs from ctx.last_request_output
logprobs=[],
)
)
elif (
ctx.parser.current_channel == "analysis"
and ctx.parser.current_recipient is None
):
if not sent_output_item_added:
sent_output_item_added = True
yield _send_event(
openai_responses_types.ResponseOutputItemAddedEvent(
type="response.output_item.added",
sequence_number=-1,
output_index=current_output_index,
item=openai_responses_types.ResponseReasoningItem(
type="reasoning",
id=current_item_id,
summary=[],
status="in_progress",
),
)
)
yield _send_event(
openai_responses_types.ResponseContentPartAddedEvent(
type="response.content_part.added",
sequence_number=-1,
output_index=current_output_index,
item_id=current_item_id,
content_index=current_content_index,
                            # TODO: migrate this to ResponseReasoningTextContent;
                            # use ResponseOutputText for now.
part=openai_responses_types.ResponseOutputText(
type="output_text",
text="",
annotations=[],
logprobs=[],
),
)
)
# TODO: migrate to OpenAI types once updated.
yield _send_event(
openai_responses_types.ResponseReasoningTextDeltaEvent(
type="response.reasoning_text.delta",
item_id=current_item_id,
output_index=current_output_index,
content_index=current_content_index,
delta=ctx.parser.last_content_delta,
sequence_number=-1,
)
)
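            # The model handed control to a built-in tool (browser / python);
            # emit the corresponding Responses API tool-call events.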
if ctx.is_assistant_action_turn() and len(ctx.parser.messages) > 0:
previous_item = ctx.parser.messages[-1]
if (
self.supports_browsing
and previous_item.recipient is not None
and previous_item.recipient.startswith("browser.")
):
function_name = previous_item.recipient[len("browser.") :]
action = None
parsed_args = json.loads(previous_item.content[0].text)
if function_name == "search":
action = openai_responses_types.response_function_web_search.ActionSearch(
type="search",
query=parsed_args["query"],
)
elif function_name == "open":
action = openai_responses_types.response_function_web_search.ActionOpenPage(
type="open_page",
# TODO: translate to url
url=f"cursor:{parsed_args.get('cursor', '')}",
)
elif function_name == "find":
action = openai_responses_types.response_function_web_search.ActionFind(
type="find",
pattern=parsed_args["pattern"],
# TODO: translate to url
url=f"cursor:{parsed_args.get('cursor', '')}",
)
else:
raise ValueError(f"Unknown function name: {function_name}")
yield _send_event(
openai_responses_types.ResponseOutputItemAddedEvent(
type="response.output_item.added",
sequence_number=-1,
output_index=current_output_index,
item=openai_responses_types.response_function_web_search.ResponseFunctionWebSearch(
# TODO: generate a unique id for web search call
type="web_search_call",
id=current_item_id,
action=action,
status="in_progress",
),
)
)
yield _send_event(
openai_responses_types.ResponseWebSearchCallInProgressEvent(
type="response.web_search_call.in_progress",
sequence_number=-1,
output_index=current_output_index,
item_id=current_item_id,
)
)
yield _send_event(
openai_responses_types.ResponseWebSearchCallSearchingEvent(
type="response.web_search_call.searching",
sequence_number=-1,
output_index=current_output_index,
item_id=current_item_id,
)
)
# enqueue
yield _send_event(
openai_responses_types.ResponseWebSearchCallCompletedEvent(
type="response.web_search_call.completed",
sequence_number=-1,
output_index=current_output_index,
item_id=current_item_id,
)
)
yield _send_event(
openai_responses_types.ResponseOutputItemDoneEvent(
type="response.output_item.done",
sequence_number=-1,
output_index=current_output_index,
item=openai_responses_types.ResponseFunctionWebSearch(
type="web_search_call",
id=current_item_id,
action=action,
status="completed",
),
)
)
if (
self.supports_code_interpreter
and previous_item.recipient is not None
and previous_item.recipient.startswith("python")
):
yield _send_event(
openai_responses_types.ResponseOutputItemAddedEvent(
type="response.output_item.added",
sequence_number=-1,
output_index=current_output_index,
item=openai_responses_types.ResponseCodeInterpreterToolCallParam(
type="code_interpreter_call",
id=current_item_id,
code="",
container_id="auto",
outputs=[],
status="in_progress",
),
)
)
yield _send_event(
openai_responses_types.ResponseCodeInterpreterCallInProgressEvent(
type="response.code_interpreter_call.in_progress",
sequence_number=-1,
output_index=current_output_index,
item_id=current_item_id,
)
)
# TODO: do we need to add delta event here?
yield _send_event(
openai_responses_types.ResponseCodeInterpreterCallCodeDoneEvent(
type="response.code_interpreter_call_code.done",
sequence_number=-1,
output_index=current_output_index,
item_id=current_item_id,
code=previous_item.content[0].text,
)
)
yield _send_event(
openai_responses_types.ResponseCodeInterpreterCallInterpretingEvent(
type="response.code_interpreter_call.interpreting",
sequence_number=-1,
output_index=current_output_index,
item_id=current_item_id,
)
)
yield _send_event(
openai_responses_types.ResponseCodeInterpreterCallCompletedEvent(
type="response.code_interpreter_call.completed",
sequence_number=-1,
output_index=current_output_index,
item_id=current_item_id,
)
)
yield _send_event(
openai_responses_types.ResponseOutputItemDoneEvent(
type="response.output_item.done",
sequence_number=-1,
output_index=current_output_index,
item=openai_responses_types.ResponseCodeInterpreterToolCallParam(
type="code_interpreter_call",
id=current_item_id,
code=previous_item.content[0].text,
container_id="auto",
# TODO: add outputs here
outputs=[],
status="completed",
),
)
)
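        # Everything has been streamed above; build the final response from the
        # already-populated context. The helper below is an async generator that
        # yields nothing, so no additional generation happens here.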
async def empty_async_generator():
if False:
yield
final_response = await self.responses_full_generator(
request,
sampling_params,
empty_async_generator(),
context,
model_name,
tokenizer,
request_metadata,
created_time=created_time,
)
# Convert final_response to the format expected by ResponseCompletedEvent
response_dict = final_response.model_dump()
# Convert UsageInfo to ResponseUsage format
if response_dict.get("usage"):
usage_info = response_dict["usage"]
response_dict["usage"] = {
"input_tokens": usage_info.get("prompt_tokens", 0),
"input_tokens_details": {
"cached_tokens": usage_info.get("cached_tokens", 0)
},
"output_tokens": usage_info.get("completion_tokens", 0),
"output_tokens_details": {
"reasoning_tokens": usage_info.get("reasoning_tokens", 0)
},
"total_tokens": usage_info.get("total_tokens", 0),
}
yield _send_event(
openai_responses_types.ResponseCompletedEvent(
type="response.completed",
sequence_number=-1,
response=response_dict,
)
)
async def _generate_with_builtin_tools(
self,
request_id: str,
request_prompt: Any,
adapted_request: GenerateReqInput,
sampling_params: Any,
context: ConversationContext,
raw_request: Optional[Request] = None,
priority: Optional[int] = None,
**kwargs,
) -> AsyncGenerator[Any, None]:
"""Generate with builtin tool support for harmony-based models."""
orig_priority = priority or 0
while True:
# Generate using SGLang's tokenizer manager
generator = self.tokenizer_manager.generate_request(
adapted_request, raw_request
)
async for res in generator:
context.append_output(res)
# NOTE(woosuk): The stop condition is handled by the engine.
yield context
if not context.need_builtin_tool_call():
# The model did not ask for a tool call, so we're done.
break
# Call the tool and update the context with the result.
tool_output = await context.call_tool()
context.append_output(tool_output)
# Prepare for the next generation turn
# Render the updated conversation for the next completion
prompt_token_ids = context.render_for_completion()
# Update the adapted request with new prompt
adapted_request = GenerateReqInput(
input_ids=prompt_token_ids,
sampling_params=sampling_params,
stream=adapted_request.stream,
rid=request_id,
return_logprob=adapted_request.return_logprob,
logprob_start_len=adapted_request.logprob_start_len,
top_logprobs_num=adapted_request.top_logprobs_num,
return_text_in_logprobs=adapted_request.return_text_in_logprobs,
return_hidden_states=adapted_request.return_hidden_states,
background=adapted_request.background,
)
# Update sampling params with reduced max_tokens
if hasattr(sampling_params, "max_new_tokens") or isinstance(
sampling_params, dict
):
context_len = getattr(
self.tokenizer_manager.model_config, "context_len", 4096
)
remaining_tokens = context_len - len(prompt_token_ids) - 1
if isinstance(sampling_params, dict):
sampling_params["max_new_tokens"] = max(remaining_tokens, 1)
else:
sampling_params.max_new_tokens = max(remaining_tokens, 1)
# Slightly reduce priority for subsequent tool calls
priority = orig_priority - 1
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import logging
from abc import ABC, abstractmethod
from contextlib import AbstractAsyncContextManager, asynccontextmanager
from typing import Any
logger = logging.getLogger(__name__)
try:
from mcp import ClientSession
from mcp.client.sse import sse_client
from mcp.types import ListToolsResult
except ImportError:
logger.warning("Ignoring mcp import error")
from openai_harmony import ToolDescription, ToolNamespaceConfig
async def list_server_and_tools(server_url: str):
async with sse_client(url=server_url) as streams, ClientSession(
*streams
) as session:
initialize_response = await session.initialize()
list_tools_response = await session.list_tools()
return initialize_response, list_tools_response
def trim_schema(schema: dict) -> dict:
    # Turn the JSON Schema generated by MCP into Harmony's variant.
if "title" in schema:
del schema["title"]
if "default" in schema and schema["default"] is None:
del schema["default"]
if "anyOf" in schema:
# Turn "anyOf": [{"type": "type-1"}, {"type": "type-2"}]
# into "type": ["type-1", "type-2"]
# if there's more than 1 types, also remove "null" type as Harmony will
# just ignore it
types = [
type_dict["type"]
for type_dict in schema["anyOf"]
if type_dict["type"] != "null"
]
schema["type"] = types
del schema["anyOf"]
if "properties" in schema:
schema["properties"] = {
k: trim_schema(v) for k, v in schema["properties"].items()
}
return schema
def post_process_tools_description(
list_tools_result: "ListToolsResult",
) -> "ListToolsResult":
# Adapt the MCP tool result for Harmony
for tool in list_tools_result.tools:
tool.inputSchema = trim_schema(tool.inputSchema)
    # Some tool schemas don't need to be part of the prompt (e.g. simple
    # text-in/text-out tools like Python).
list_tools_result.tools = [
tool
for tool in list_tools_result.tools
if getattr(tool.annotations, "include_in_prompt", True)
]
return list_tools_result
class ToolServer(ABC):
@abstractmethod
def has_tool(self, tool_name: str):
pass
@abstractmethod
def get_tool_description(self, tool_name: str):
pass
@abstractmethod
def get_tool_session(self, tool_name: str) -> AbstractAsyncContextManager[Any]: ...
class MCPToolServer(ToolServer):
def __init__(self):
self.harmony_tool_descriptions = {}
async def add_tool_server(self, server_url: str):
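        # `server_url` may be a comma-separated list of host:port entries; each
        # one is contacted over SSE to collect its tool descriptions.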
tool_urls = server_url.split(",")
self.harmony_tool_descriptions = {}
self.urls: dict[str, str] = {}
for url in tool_urls:
url = f"http://{url}/sse"
initialize_response, list_tools_response = await list_server_and_tools(url)
list_tools_response = post_process_tools_description(list_tools_response)
tool_from_mcp = ToolNamespaceConfig(
name=initialize_response.serverInfo.name,
description=initialize_response.instructions,
tools=[
ToolDescription.new(
name=tool.name,
description=tool.description,
parameters=tool.inputSchema,
)
for tool in list_tools_response.tools
],
)
self.harmony_tool_descriptions[tool_from_mcp.name] = tool_from_mcp
if tool_from_mcp.name not in self.urls:
self.urls[tool_from_mcp.name] = url
else:
logger.warning(
"Tool %s already exists. Ignoring duplicate tool server %s",
tool_from_mcp.name,
url,
)
def has_tool(self, tool_name: str):
return tool_name in self.harmony_tool_descriptions
def get_tool_description(self, tool_name: str):
return self.harmony_tool_descriptions.get(tool_name)
@asynccontextmanager
async def get_tool_session(self, tool_name: str):
url = self.urls.get(tool_name)
if url:
async with sse_client(url=url) as streams, ClientSession(
*streams
) as session:
await session.initialize()
yield session
else:
logger.warning("Tool %s not found", tool_name)
class DemoToolServer(ToolServer):
def __init__(self):
from sglang.srt.entrypoints.tool import (
HarmonyBrowserTool,
HarmonyPythonTool,
Tool,
)
self.tools: dict[str, Tool] = {}
browser_tool = HarmonyBrowserTool()
if browser_tool.enabled:
self.tools["browser"] = browser_tool
python_tool = HarmonyPythonTool()
if python_tool.enabled:
self.tools["python"] = python_tool
def has_tool(self, tool_name: str):
return tool_name in self.tools
def get_tool_description(self, tool_name: str):
if tool_name not in self.tools:
return None
if tool_name == "browser":
return ToolNamespaceConfig.browser()
elif tool_name == "python":
return ToolNamespaceConfig.python()
else:
raise ValueError(f"Unknown tool {tool_name}")
@asynccontextmanager
async def get_tool_session(self, tool_name: str):
yield self.tools[tool_name]
# SPDX-License-Identifier: Apache-2.0
import logging
import os
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
# Avoid circular import.
from sglang.srt.entrypoints.context import ConversationContext
logger = logging.getLogger(__name__)
class Tool(ABC):
@abstractmethod
async def get_result(self, context: "ConversationContext") -> Any:
pass
class HarmonyBrowserTool(Tool):
def __init__(self):
self.enabled = True
exa_api_key = os.getenv("EXA_API_KEY")
if not exa_api_key:
self.enabled = False
logger.warning_once("EXA_API_KEY is not set, browsing is disabled")
return
try:
from gpt_oss.tools.simple_browser import SimpleBrowserTool
from gpt_oss.tools.simple_browser.backend import ExaBackend
except ImportError:
self.enabled = False
logger.warning_once("gpt_oss is not installed, browsing is disabled")
return
browser_backend = ExaBackend(source="web", api_key=exa_api_key)
self.browser_tool = SimpleBrowserTool(backend=browser_backend)
logger.info_once("Browser tool initialized")
async def get_result(self, context: "ConversationContext") -> Any:
from sglang.srt.entrypoints.context import HarmonyContext
assert isinstance(context, HarmonyContext)
last_msg = context.messages[-1]
tool_output_msgs = []
async for msg in self.browser_tool.process(last_msg):
tool_output_msgs.append(msg)
return tool_output_msgs
@property
def tool_config(self) -> Any:
return self.browser_tool.tool_config
class HarmonyPythonTool(Tool):
def __init__(self):
self.enabled = True
try:
from gpt_oss.tools.python_docker.docker_tool import PythonTool
except ImportError:
self.enabled = False
logger.warning_once(
"gpt_oss is not installed, code interpreter is disabled"
)
return
self.python_tool = PythonTool()
logger.info_once("Code interpreter tool initialized")
async def get_result(self, context: "ConversationContext") -> Any:
from sglang.srt.entrypoints.context import HarmonyContext
assert isinstance(context, HarmonyContext)
last_msg = context.messages[-1]
tool_output_msgs = []
async for msg in self.python_tool.process(last_msg):
tool_output_msgs.append(msg)
return tool_output_msgs
@property
def tool_config(self) -> Any:
return self.python_tool.tool_config
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Harmony tool call parser for processing tool calls in harmony models."""
import uuid
from typing import List, Optional, Tuple
from sglang.srt.entrypoints.openai.protocol import (
ChatMessage,
FunctionResponse,
ToolCall,
)
class HarmonyToolCallParser:
"""Parser for extracting tool calls from harmony model outputs."""
def extract_tool_calls_from_message(self, msg) -> Optional[ToolCall]:
"""
Extract tool call from a single message if it's a tool call.
Args:
msg: The harmony message
Returns:
ToolCall if the message is a tool call, None otherwise
"""
if (
msg.channel == "commentary"
and msg.recipient
and msg.recipient.startswith("functions.")
):
function_name = msg.recipient.split(".")[-1]
arguments = msg.content[0].text if msg.content else "{}"
return ToolCall(
id=f"call_{uuid.uuid4().hex[:24]}",
function=FunctionResponse(
name=function_name,
arguments=arguments,
),
)
return None
def process_streaming_chunk(
self,
harmony_parser,
index: int,
tool_call_trackers: dict,
stream_buffers: dict,
) -> Tuple[Optional[dict], bool, Optional[str]]:
"""
Process a streaming chunk for tool calls.
Args:
harmony_parser: The harmony parser instance
index: The choice index
tool_call_trackers: Dict tracking tool calls per choice
stream_buffers: Dict for buffering content
Returns:
Tuple of (tool_call_data, is_tool_call, delta)
"""
# Check if we're in a tool call
is_tool_call = (
harmony_parser.current_channel == "commentary"
and harmony_parser.current_recipient
and harmony_parser.current_recipient.startswith("functions.")
)
delta = harmony_parser.last_content_delta or ""
tool_call_data = None
if is_tool_call:
# Handle tool call streaming
function_name = harmony_parser.current_recipient.split(".")[-1]
# Track tool call indices per choice
if index not in tool_call_trackers:
tool_call_trackers[index] = {"count": 0, "current_function": None}
# Check if we just started a new tool call
tool_call_tracker = tool_call_trackers[index]
if tool_call_tracker["current_function"] != function_name:
# New tool call started
tool_call_tracker["current_function"] = function_name
tool_call_index = tool_call_tracker["count"]
tool_call_tracker["count"] += 1
# Store the tool call index for this function
tool_call_key = f"{index}_{function_name}"
stream_buffers[tool_call_key] = {
"index": tool_call_index,
"content": "",
}
tool_call_data = {
"id": f"call_{uuid.uuid4().hex[:24]}",
"index": tool_call_index,
"function_name": function_name,
"arguments": delta,
"is_first_chunk": True,
}
else:
# Subsequent chunks for the same tool call
tool_call_key = f"{index}_{function_name}"
tool_call_index = stream_buffers[tool_call_key]["index"]
tool_call_data = {
"id": None,
"index": tool_call_index,
"function_name": None,
"arguments": delta,
"is_first_chunk": False,
}
stream_buffers[tool_call_key]["content"] += delta
return tool_call_data, is_tool_call, delta
@@ -216,7 +216,7 @@ class DetokenizerManager:
                 rids=recv_obj.rids,
                 finished_reasons=recv_obj.finished_reasons,
                 output_strs=output_strs,
-                output_ids=None,
+                output_ids=recv_obj.decode_ids,
                 prompt_tokens=recv_obj.prompt_tokens,
                 completion_tokens=recv_obj.completion_tokens,
                 cached_tokens=recv_obj.cached_tokens,
@@ -126,6 +126,9 @@ class GenerateReqInput:
     # For data parallel rank routing
     data_parallel_rank: Optional[int] = None
 
+    # For background responses (OpenAI responses API)
+    background: bool = False
+
     def contains_mm_input(self) -> bool:
         return (
             has_valid_data(self.image_data)

@@ -560,6 +563,9 @@ class EmbeddingReqInput:
     # For cross-encoder requests
     is_cross_encoder_request: bool = False
 
+    # For background responses (OpenAI responses API)
+    background: bool = False
+
     def normalize_batch_and_arguments(self):
         # at least one of text, input_ids, or image should be provided
         if self.text is None and self.input_ids is None and self.image_data is None:
@@ -571,8 +571,7 @@ class SchedulerOutputProcessorMixin:
                 req.send_decode_id_offset = len(decode_ids)
                 read_offsets.append(read_offset)
-                if self.skip_tokenizer_init:
-                    output_ids.append(req.output_ids[send_token_offset:])
+                output_ids.append(req.output_ids[send_token_offset:])
                 req.send_token_offset = len(req.output_ids)
                 skip_special_tokens.append(req.sampling_params.skip_special_tokens)
                 spaces_between_special_tokens.append(
@@ -750,7 +750,11 @@ class TokenizerManager:
             try:
                 await asyncio.wait_for(state.event.wait(), timeout=4)
             except asyncio.TimeoutError:
-                if request is not None and await request.is_disconnected():
+                if (
+                    request is not None
+                    and not obj.background
+                    and await request.is_disconnected()
+                ):
                     # Abort the request for disconnected requests (non-streaming, waiting queue)
                     self.abort_request(obj.rid)
                     # Use exception to kill the whole call stack and asyncio task

@@ -805,7 +809,11 @@ class TokenizerManager:
             if obj.stream:
                 yield out
             else:
-                if request is not None and await request.is_disconnected():
+                if (
+                    request is not None
+                    and not obj.background
+                    and await request.is_disconnected()
+                ):
                     # Abort the request for disconnected requests (non-streaming, running)
                     self.abort_request(obj.rid)
                     # Use exception to kill the whole call stack and asyncio task

@@ -1548,8 +1556,17 @@ class TokenizerManager:
             if isinstance(recv_obj, BatchStrOut):
                 state.text += recv_obj.output_strs[i]
+
+                if state.obj.stream:
+                    state.output_ids.extend(recv_obj.output_ids[i])
+                    output_token_ids = state.output_ids[state.last_output_offset :]
+                    state.last_output_offset = len(state.output_ids)
+                else:
+                    state.output_ids.extend(recv_obj.output_ids[i])
+                    output_token_ids = state.output_ids.copy()
+
                 out_dict = {
                     "text": state.text,
+                    "output_ids": output_token_ids,
                     "meta_info": meta_info,
                 }
             elif isinstance(recv_obj, BatchTokenIDOut):
@@ -274,6 +274,9 @@ class ServerArgs:
     enable_pdmux: bool = False
     sm_group_num: int = 3
 
+    # For tool server
+    tool_server: Optional[str] = None
+
     # Deprecated arguments
     enable_ep_moe: bool = False
     enable_deepep_moe: bool = False

@@ -1916,6 +1919,14 @@ class ServerArgs:
             help="Disable mmap while loading weight using safetensors.",
         )
 
+        # For tool server
+        parser.add_argument(
+            "--tool-server",
+            type=str,
+            default=None,
+            help="Either 'demo' or a comma-separated list of tool server urls to use for the model. If not specified, no tool server will be used.",
+        )
+
         # Deprecated arguments
         parser.add_argument(
             "--enable-ep-moe",
@@ -41,6 +41,7 @@ import tempfile
 import threading
 import time
 import traceback
+import uuid
 import warnings
 from collections import OrderedDict, defaultdict
 from contextlib import contextmanager

@@ -233,6 +234,10 @@ def is_flashinfer_available():
     return importlib.util.find_spec("flashinfer") is not None and is_cuda()
 
 
+def random_uuid() -> str:
+    return str(uuid.uuid4().hex)
+
+
 _ENABLE_TORCH_INFERENCE_MODE = get_bool_env_var(
     "SGLANG_ENABLE_TORCH_INFERENCE_MODE", "false"
 )