Unverified Commit 92cc32d9 authored by Chang Su, committed by GitHub

Support v1/responses and use harmony in serving_chat (#8837)


Signed-off-by: Xinyuan Tong <justinning0323@outlook.com>
Signed-off-by: Xinyuan Tong <xinyuantong.cs@gmail.com>
Co-authored-by: Xinyuan Tong <justinning0323@outlook.com>
Co-authored-by: Xinyuan Tong <xinyuantong.cs@gmail.com>
parent cbbd685a
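For reference, once a server built from this commit is running, the new endpoints can be exercised with the official OpenAI client. This is only a minimal sketch; the base URL, port, and model name are illustrative assumptions, not part of this change.

# Hypothetical client-side usage of the new /v1/responses endpoints.
# Assumes an SGLang server with a gpt-oss model listening on localhost:30000.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:30000/v1", api_key="EMPTY")

# POST /v1/responses
resp = client.responses.create(
    model="openai/gpt-oss-20b",  # illustrative model name
    input="Summarize the harmony message format in one sentence.",
    reasoning={"effort": "low"},
)
print(resp.output)

# GET /v1/responses/{response_id}
fetched = client.responses.retrieve(resp.id)

# POST /v1/responses/{response_id}/cancel (useful for background=True requests)
client.responses.cancel(resp.id)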
......@@ -29,6 +29,7 @@ runtime_common = [
"modelscope",
"msgspec",
"ninja",
"openai-harmony==0.0.3",
"orjson",
"outlines==0.1.11",
"packaging",
......@@ -96,7 +97,7 @@ srt_cpu = ["sglang[runtime_common]", "einops"]
# https://vllm-ascend.readthedocs.io/en/latest/installation.html
srt_npu = ["sglang[runtime_common]"]
openai = ["openai>=1.0", "tiktoken"]
openai = ["openai>=1.99.1", "tiktoken"]
anthropic = ["anthropic>=0.20.0"]
litellm = ["litellm>=1.0.0"]
torch_memory_saver = ["torch_memory_saver>=0.0.8"]
......
# SPDX-License-Identifier: Apache-2.0
# Copied from vLLM
import json
import logging
from abc import ABC, abstractmethod
from typing import Union
logger = logging.getLogger(__name__)
try:
from mcp import ClientSession
except ImportError:
logger.warning("Ignoring mcp import error")
from openai_harmony import Author, Message, Role, StreamState, TextContent
from sglang.srt.entrypoints.harmony_utils import (
get_encoding,
get_streamable_parser_for_assistant,
render_for_completion,
)
from sglang.srt.entrypoints.tool import Tool
class ConversationContext(ABC):
@abstractmethod
def append_output(self, output) -> None:
pass
@abstractmethod
async def call_tool(self) -> list[Message]:
pass
@abstractmethod
def need_builtin_tool_call(self) -> bool:
pass
@abstractmethod
def render_for_completion(self) -> list[int]:
pass
class SimpleContext(ConversationContext):
def __init__(self):
self.last_output = None
def append_output(self, output) -> None:
self.last_output = output
def need_builtin_tool_call(self) -> bool:
return False
async def call_tool(self) -> list[Message]:
raise NotImplementedError("Should not be called.")
def render_for_completion(self) -> list[int]:
raise NotImplementedError("Should not be called.")
class HarmonyContext(ConversationContext):
def __init__(
self,
messages: list,
tool_sessions: dict[str, Union["ClientSession", Tool]],
):
# TODO: Remove the Union[ClientSession, Tool] hack by routing the demo
# tools through MCP as well.
self._messages = messages
self.tool_sessions = tool_sessions
self.parser = get_streamable_parser_for_assistant()
self.num_init_messages = len(messages)
# TODO
self.num_prompt_tokens = 0
self.num_cached_tokens = 0
self.num_output_tokens = 0
self.num_reasoning_tokens = 0
def append_output(self, output) -> None:
if isinstance(output, dict) and "output_ids" in output:
output_token_ids = output["output_ids"]
# TODO: remove this hack.
# Find the first occurrence of the <|start|> token (200006) and cut from there.
try:
start_index = output_token_ids.index(200006)
output_token_ids = output_token_ids[start_index:]
except ValueError:
pass
for token_id in output_token_ids:
self.parser.process(token_id)
output_msgs = self.parser.messages
meta_info = output["meta_info"]
if isinstance(meta_info, dict):
if "prompt_token_ids" in meta_info:
self.num_prompt_tokens = meta_info["prompt_tokens"]
if "cached_tokens" in meta_info:
self.num_cached_tokens = meta_info["cached_tokens"]
if "completion_tokens" in meta_info:
self.num_output_tokens += meta_info["completion_tokens"]
else:
output_msgs = output
self._messages.extend(output_msgs)
@property
def messages(self) -> list:
return self._messages
def need_builtin_tool_call(self) -> bool:
last_msg = self.messages[-1]
recipient = last_msg.recipient
return recipient is not None and (
recipient.startswith("browser.") or recipient.startswith("python")
)
async def call_tool(self) -> list[Message]:
if not self.messages:
return []
last_msg = self.messages[-1]
recipient = last_msg.recipient
if recipient is not None:
if recipient.startswith("browser."):
return await self.call_search_tool(
self.tool_sessions["browser"], last_msg
)
elif recipient.startswith("python"):
return await self.call_python_tool(
self.tool_sessions["python"], last_msg
)
raise ValueError("No tool call found")
def render_for_completion(self) -> list[int]:
return render_for_completion(self.messages)
async def call_search_tool(
self, tool_session: Union["ClientSession", Tool], last_msg: Message
) -> list[Message]:
if isinstance(tool_session, Tool):
return await tool_session.get_result(self)
tool_name = last_msg.recipient.split(".")[1]
args = json.loads(last_msg.content[0].text)
result = await tool_session.call_tool(tool_name, args)
result_str = result.content[0].text
content = TextContent(text=result_str)
author = Author(role=Role.TOOL, name=last_msg.recipient)
return [Message(author=author, content=[content], recipient=Role.ASSISTANT)]
async def call_python_tool(
self, tool_session: Union["ClientSession", Tool], last_msg: Message
) -> list[Message]:
if isinstance(tool_session, Tool):
return await tool_session.get_result(self)
param = {
"code": last_msg.content[0].text,
}
result = await tool_session.call_tool("python", param)
result_str = result.content[0].text
content = TextContent(text=result_str)
author = Author(role=Role.TOOL, name="python")
return [
Message(
author=author,
content=[content],
channel=last_msg.channel,
recipient=Role.ASSISTANT,
)
]
class StreamingHarmonyContext(HarmonyContext):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.last_output = None
self.parser = get_streamable_parser_for_assistant()
self.encoding = get_encoding()
self.last_tok = None
@property
def messages(self) -> list:
return self.parser.messages
def append_output(self, output) -> None:
if isinstance(output, dict) and "output_ids" in output:
# Dict output from SGLang containing generated token ids
output_token_ids = output["output_ids"]
# TODO: remove this hack.
# Find the first occurrence of the <|start|> token (200006) and cut from there.
try:
start_index = output_token_ids.index(200006)
output_token_ids = output_token_ids[start_index:]
except ValueError:
pass
for token_id in output_token_ids:
self.parser.process(token_id)
else:
# Handle the case of tool output in direct message format
assert len(output) == 1, "Tool output should be a single message"
msg = output[0]
# Sometimes the recipient is not set for tool messages,
# so we set it to "assistant"
if msg.author.role == Role.TOOL and msg.recipient is None:
msg.recipient = "assistant"
toks = self.encoding.render(msg)
for tok in toks:
self.parser.process(tok)
self.last_tok = toks[-1]
def is_expecting_start(self) -> bool:
return self.parser.state == StreamState.EXPECT_START
def is_assistant_action_turn(self) -> bool:
return self.last_tok in self.encoding.stop_tokens_for_assistant_actions()
def render_for_completion(self) -> list[int]:
# The rendered tokens end with the next turn's starting tokens
# (e.g. `<|start|>assistant`), which the parser has not seen yet,
# so feed them through the parser here before returning.
rendered_tokens = super().render_for_completion()
last_n = -1
to_process = []
while rendered_tokens[last_n] != self.last_tok:
to_process.append(rendered_tokens[last_n])
last_n -= 1
for tok in reversed(to_process):
self.parser.process(tok)
return rendered_tokens
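For orientation, the context classes above are meant to be driven by a serving loop along these lines. This is a rough sketch: `generate` is a placeholder for the actual SGLang generation call and is not a real API; the context methods are exactly those defined above.

# Sketch of the agentic loop ConversationContext is designed for.
async def run_until_final(context: HarmonyContext, generate) -> None:
    while True:
        prompt_ids = context.render_for_completion()   # harmony-rendered prompt tokens
        output = await generate(prompt_ids)            # e.g. {"output_ids": [...], "meta_info": {...}}
        context.append_output(output)                  # parse tokens into harmony messages
        if not context.need_builtin_tool_call():
            break                                      # assistant produced a final answer
        tool_msgs = await context.call_tool()          # browser / python tool round-trip
        context.append_output(tool_msgs)               # feed tool results back into the conversation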
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import datetime
import json
from collections.abc import Iterable
from typing import Literal, Optional, Union
from openai.types.responses import (
ResponseOutputItem,
ResponseOutputMessage,
ResponseOutputText,
ResponseReasoningItem,
)
from openai.types.responses.response_function_tool_call import ResponseFunctionToolCall
from openai.types.responses.response_function_web_search import (
ActionFind,
ActionOpenPage,
ActionSearch,
ResponseFunctionWebSearch,
)
from openai.types.responses.response_reasoning_item import (
Content as ResponseReasoningTextContent,
)
from openai.types.responses.tool import Tool
from openai_harmony import (
Author,
Conversation,
DeveloperContent,
HarmonyEncodingName,
Message,
ReasoningEffort,
Role,
StreamableParser,
SystemContent,
TextContent,
ToolDescription,
load_harmony_encoding,
)
from sglang.srt.entrypoints.openai.protocol import ResponseInputOutputItem
from sglang.srt.utils import random_uuid
REASONING_EFFORT = {
"high": ReasoningEffort.HIGH,
"medium": ReasoningEffort.MEDIUM,
"low": ReasoningEffort.LOW,
}
_harmony_encoding = None
def get_encoding():
global _harmony_encoding
if _harmony_encoding is None:
_harmony_encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
return _harmony_encoding
def get_system_message(
model_identity: Optional[str] = None,
reasoning_effort: Optional[Literal["high", "medium", "low"]] = None,
start_date: Optional[str] = None,
browser_description: Optional[str] = None,
python_description: Optional[str] = None,
) -> Message:
sys_msg_content = SystemContent.new()
if model_identity is not None:
sys_msg_content = sys_msg_content.with_model_identity(model_identity)
if reasoning_effort is not None:
sys_msg_content = sys_msg_content.with_reasoning_effort(
REASONING_EFFORT[reasoning_effort]
)
if start_date is None:
start_date = datetime.datetime.now().strftime("%Y-%m-%d")
sys_msg_content = sys_msg_content.with_conversation_start_date(start_date)
if browser_description is not None:
sys_msg_content = sys_msg_content.with_tools(browser_description)
if python_description is not None:
sys_msg_content = sys_msg_content.with_tools(python_description)
sys_msg = Message.from_role_and_content(Role.SYSTEM, sys_msg_content)
return sys_msg
def get_developer_message(
instructions: Optional[str] = None, tools: Optional[list[Tool]] = None
) -> Message:
dev_msg_content = DeveloperContent.new()
if instructions is not None:
dev_msg_content = dev_msg_content.with_instructions(instructions)
if tools is not None:
function_tools = []
for tool in tools:
if tool.type in ("web_search_preview", "code_interpreter"):
# These are built-in tools that are added to the system message.
pass
elif tool.type == "function":
function_tools.append(tool)
else:
raise ValueError(f"tool type {tool.type} not supported")
if function_tools:
function_tool_descriptions = [
ToolDescription.new(
name=tool.name,
description=tool.description,
parameters=tool.parameters,
)
for tool in function_tools
]
dev_msg_content = dev_msg_content.with_function_tools(
function_tool_descriptions
)
dev_msg = Message.from_role_and_content(Role.DEVELOPER, dev_msg_content)
return dev_msg
def get_user_message(content: str) -> Message:
return Message.from_role_and_content(Role.USER, content)
def parse_response_input(
response_msg: ResponseInputOutputItem,
prev_responses: list[Union[ResponseOutputItem, ResponseReasoningItem]],
) -> Message:
if not isinstance(response_msg, dict):
response_msg = response_msg.model_dump()
if "type" not in response_msg or response_msg["type"] == "message":
role = response_msg["role"]
content = response_msg["content"]
if role == "system":
# User is trying to set a system message. Change it to:
# <|start|>developer<|message|># Instructions
# {instructions}<|end|>
role = "developer"
text_prefix = "Instructions:\n"
else:
text_prefix = ""
if isinstance(content, str):
msg = Message.from_role_and_content(role, text_prefix + content)
else:
contents = [TextContent(text=text_prefix + c["text"]) for c in content]
msg = Message.from_role_and_contents(role, contents)
elif response_msg["type"] == "function_call_output":
call_id = response_msg["call_id"]
call_response: Optional[ResponseFunctionToolCall] = None
for prev_response in reversed(prev_responses):
if (
isinstance(prev_response, ResponseFunctionToolCall)
and prev_response.call_id == call_id
):
call_response = prev_response
break
if call_response is None:
raise ValueError(f"No call message found for {call_id}")
msg = Message.from_author_and_content(
Author.new(Role.TOOL, f"functions.{call_response.name}"),
response_msg["output"],
)
elif response_msg["type"] == "reasoning":
content = response_msg["content"]
assert len(content) == 1
msg = Message.from_role_and_content(Role.ASSISTANT, content[0]["text"])
elif response_msg["type"] == "function_call":
msg = Message.from_role_and_content(Role.ASSISTANT, response_msg["arguments"])
msg = msg.with_channel("commentary")
msg = msg.with_recipient(f"functions.{response_msg['name']}")
msg = msg.with_content_type("json")
else:
raise ValueError(f"Unknown input type: {response_msg['type']}")
return msg
def parse_response_output(output: ResponseOutputItem) -> Message:
if isinstance(output, ResponseOutputMessage):
role = output.role
contents = [TextContent(text=c.text) for c in output.content]
msg = Message.from_role_and_contents(role, contents)
return msg
elif isinstance(output, ResponseFunctionToolCall):
msg = Message.from_role_and_content(Role.ASSISTANT, output.arguments)
msg = msg.with_channel("commentary")
msg = msg.with_recipient(output.name)
msg = msg.with_content_type("json")
return msg
else:
raise ValueError(f"Unknown output type: {type(output)}")
def parse_chat_input(chat_msg) -> Message:
role = chat_msg.role
content = chat_msg.content
if isinstance(content, str):
contents = [TextContent(text=content)]
else:
# TODO: Support refusal.
contents = [TextContent(text=c.text) for c in content]
msg = Message.from_role_and_contents(role, contents)
return msg
def render_for_completion(messages: list[Message]) -> list[int]:
conversation = Conversation.from_messages(messages)
token_ids = get_encoding().render_conversation_for_completion(
conversation, Role.ASSISTANT
)
return token_ids
def get_stop_tokens_for_assistant_actions() -> list[int]:
return get_encoding().stop_tokens_for_assistant_actions()
def get_streamable_parser_for_assistant() -> StreamableParser:
return StreamableParser(get_encoding(), role=Role.ASSISTANT)
def parse_output_message(message: Message):
if message.author.role != "assistant":
# This is a message from a tool to the assistant (e.g., search result).
# Don't include it in the final output for now. This aligns with
# OpenAI's behavior on models like o4-mini.
return []
output_items = []
recipient = message.recipient
if recipient is not None and recipient.startswith("browser."):
if len(message.content) != 1:
raise ValueError("Invalid number of contents in browser message")
content = message.content[0]
browser_call = json.loads(content.text)
# TODO: translate to url properly!
if recipient == "browser.search":
action = ActionSearch(
query=f"cursor:{browser_call.get('query', '')}", type="search"
)
elif recipient == "browser.open":
action = ActionOpenPage(
url=f"cursor:{browser_call.get('url', '')}", type="open_page"
)
elif recipient == "browser.find":
action = ActionFind(
pattern=browser_call["pattern"],
url=f"cursor:{browser_call.get('url', '')}",
type="find",
)
else:
raise ValueError(f"Unknown browser action: {recipient}")
web_search_item = ResponseFunctionWebSearch(
id=f"ws_{random_uuid()}",
action=action,
status="completed",
type="web_search_call",
)
output_items.append(web_search_item)
elif message.channel == "analysis":
for content in message.content:
reasoning_item = ResponseReasoningItem(
id=f"rs_{random_uuid()}",
type="reasoning",
summary=[],
content=[
ResponseReasoningTextContent(
text=content.text, type="reasoning_text"
)
],
status=None,
)
output_items.append(reasoning_item)
elif message.channel == "commentary":
if message.recipient.startswith("functions."):
function_name = message.recipient.split(".")[-1]
for content in message.content:
random_id = random_uuid()
response_item = ResponseFunctionToolCall(
arguments=content.text,
call_id=f"call_{random_id}",
type="function_call",
name=function_name,
id=f"ft_{random_id}",
)
output_items.append(response_item)
elif message.recipient.startswith("python") or message.recipient.startswith(
"browser"
):
for content in message.content:
reasoning_item = ResponseReasoningItem(
id=f"rs_{random_uuid()}",
type="reasoning",
summary=[],
content=[
ResponseReasoningTextContent(
text=content.text, type="reasoning_text"
)
],
status=None,
)
output_items.append(reasoning_item)
else:
raise ValueError(f"Unknown recipient: {message.recipient}")
elif message.channel == "final":
contents = []
for content in message.content:
output_text = ResponseOutputText(
text=content.text,
annotations=[], # TODO
type="output_text",
logprobs=None, # TODO
)
contents.append(output_text)
text_item = ResponseOutputMessage(
id=f"msg_{random_uuid()}",
content=contents,
role=message.author.role,
status="completed",
type="message",
)
output_items.append(text_item)
else:
raise ValueError(f"Unknown channel: {message.channel}")
return output_items
def parse_remaining_state(parser: StreamableParser):
if not parser.current_content:
return []
if parser.current_role != Role.ASSISTANT:
return []
current_recipient = parser.current_recipient
if current_recipient is not None and current_recipient.startswith("browser."):
return []
if parser.current_channel == "analysis":
reasoning_item = ResponseReasoningItem(
id=f"rs_{random_uuid()}",
type="reasoning",
summary=[],
content=[
ResponseReasoningTextContent(
text=parser.current_content, type="reasoning_text"
)
],
status=None,
)
return [reasoning_item]
elif parser.current_channel == "final":
output_text = ResponseOutputText(
text=parser.current_content,
annotations=[],  # TODO
type="output_text",
logprobs=None,  # TODO
)
text_item = ResponseOutputMessage(
id=f"msg_{random_uuid()}",
content=[output_text],
role="assistant",
status="completed",
type="message",
)
return [text_item]
return []
def parse_output_into_messages(token_ids: Iterable[int]):
parser = get_streamable_parser_for_assistant()
for token_id in token_ids:
parser.process(token_id)
return parser
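Taken together, a typical round-trip through these helpers looks roughly like the sketch below. The user text is illustrative and the engine call is elided; only functions defined in this file are used.

# Illustrative round-trip: build a harmony prompt, then parse the model output.
messages = [
    get_system_message(reasoning_effort="low"),
    get_developer_message(instructions="Answer briefly."),
    get_user_message("What is 2 + 2?"),
]
prompt_token_ids = render_for_completion(messages)

# ... run generation on prompt_token_ids and collect output_token_ids ...
output_token_ids: list[int] = []  # placeholder for the engine output

parser = parse_output_into_messages(output_token_ids)
for msg in parser.messages:
    # Completed harmony messages map onto Responses API output items.
    print(parse_output_message(msg))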
......@@ -32,6 +32,7 @@ from typing import AsyncIterator, Callable, Dict, Optional
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
from contextlib import asynccontextmanager
from typing import AsyncGenerator
import numpy as np
import orjson
......@@ -56,6 +57,7 @@ from sglang.srt.entrypoints.openai.protocol import (
ErrorResponse,
ModelCard,
ModelList,
ResponsesRequest,
ScoringRequest,
V1RerankReqInput,
)
......@@ -147,6 +149,37 @@ async def lifespan(fast_api_app: FastAPI):
)
server_args: ServerArgs = fast_api_app.server_args
tool_server = None
if server_args.tool_server == "demo":
from sglang.srt.entrypoints.openai.tool_server import DemoToolServer
tool_server = DemoToolServer()
elif server_args.tool_server:
from sglang.srt.entrypoints.openai.tool_server import MCPToolServer
tool_server = MCPToolServer()
await tool_server.add_tool_server(server_args.tool_server)
try:
from sglang.srt.entrypoints.openai.serving_responses import (
OpenAIServingResponses,
)
fast_api_app.state.openai_serving_responses = OpenAIServingResponses(
_global_state.tokenizer_manager,
_global_state.template_manager,
enable_prompt_tokens_details=True,
enable_force_include_usage=True,
tool_server=tool_server,
)
except Exception as e:
# Log the full stack trace to help debug initialization failures
import traceback
traceback.print_exc()
logger.warning(f"Cannot initialize OpenAIServingResponses, error: {e}")
if server_args.warmups is not None:
await execute_warmups(
server_args.disaggregation_mode,
......@@ -843,6 +876,42 @@ async def v1_score_request(request: ScoringRequest, raw_request: Request):
)
@app.post("/v1/responses", dependencies=[Depends(validate_json_request)])
async def v1_responses_request(request: dict, raw_request: Request):
"""Endpoint for the responses API with reasoning support."""
request_obj = ResponsesRequest(**request)
result = await raw_request.app.state.openai_serving_responses.create_responses(
request_obj, raw_request
)
# Handle streaming responses
if isinstance(result, AsyncGenerator):
return StreamingResponse(
result,
media_type="text/event-stream",
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
)
return result
@app.get("/v1/responses/{response_id}")
async def v1_retrieve_responses(response_id: str, raw_request: Request):
"""Retrieve a response by ID."""
return await raw_request.app.state.openai_serving_responses.retrieve_responses(
response_id
)
@app.post("/v1/responses/{response_id}/cancel")
async def v1_cancel_responses(response_id: str, raw_request: Request):
"""Cancel a background response."""
return await raw_request.app.state.openai_serving_responses.cancel_responses(
response_id
)
@app.api_route(
"/v1/rerank", methods=["POST", "PUT"], dependencies=[Depends(validate_json_request)]
)
......
......@@ -14,9 +14,18 @@
"""Pydantic models for OpenAI API protocol"""
import time
import uuid
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, TypeAlias, Union
from openai.types.responses import (
ResponseFunctionToolCall,
ResponseInputItemParam,
ResponseOutputItem,
ResponseReasoningItem,
)
from openai.types.responses.response import ToolChoice
from openai.types.responses.tool import Tool
from pydantic import (
BaseModel,
Field,
......@@ -84,6 +93,7 @@ class UsageInfo(BaseModel):
completion_tokens: Optional[int] = 0
# only used to return cached tokens when --enable-cache-report is set
prompt_tokens_details: Optional[Dict[str, int]] = None
reasoning_tokens: Optional[int] = 0
class StreamOptions(BaseModel):
......@@ -428,6 +438,13 @@ class ChatCompletionRequest(BaseModel):
default="auto", examples=["none"]
) # noqa
return_hidden_states: bool = False
reasoning_effort: Optional[Literal["low", "medium", "high"]] = Field(
default="medium",
description="Constrains effort on reasoning for reasoning models. "
"'low' is the least effort, 'high' is the most effort. Reducing reasoning effort can "
"result in faster responses and fewer tokens used on reasoning in a response. "
"Currently only supported for OpenAI models.",
)
@model_validator(mode="before")
@classmethod
......@@ -619,6 +636,196 @@ OpenAIServingRequest = Union[
]
# Response API protocol definitions
class ResponseReasoningParam(BaseModel):
"""Reasoning parameters for responses."""
effort: Optional[Literal["low", "medium", "high"]] = Field(
default="medium",
description="Constrains effort on reasoning for reasoning models.",
)
class ResponseTool(BaseModel):
"""Tool definition for responses."""
type: Literal["web_search_preview", "code_interpreter"] = Field(
description="Type of tool to enable"
)
ResponseInputOutputItem: TypeAlias = Union[
ResponseInputItemParam,
"ResponseReasoningItem",
ResponseFunctionToolCall,
]
class ResponsesRequest(BaseModel):
"""Request body for v1/responses endpoint."""
# Core OpenAI API fields (ordered by official documentation)
background: Optional[bool] = False
include: Optional[
List[
Literal[
"code_interpreter_call.outputs",
"computer_call_output.output.image_url",
"file_search_call.results",
"message.input_image.image_url",
"message.output_text.logprobs",
"reasoning.encrypted_content",
]
]
] = None
input: Union[str, List[ResponseInputOutputItem]]
instructions: Optional[str] = None
max_output_tokens: Optional[int] = None
max_tool_calls: Optional[int] = None
metadata: Optional[Dict[str, Any]] = None
model: Optional[str] = None # Made optional to match vLLM
parallel_tool_calls: Optional[bool] = True
previous_response_id: Optional[str] = None
reasoning: Optional[ResponseReasoningParam] = None
service_tier: Literal["auto", "default", "flex", "scale", "priority"] = "auto"
store: Optional[bool] = True
stream: Optional[bool] = False
temperature: Optional[float] = None
tool_choice: Literal["auto", "required", "none"] = "auto"
tools: List[ResponseTool] = Field(default_factory=list)
top_logprobs: Optional[int] = 0
top_p: Optional[float] = None
truncation: Optional[Literal["auto", "disabled"]] = "disabled"
user: Optional[str] = None
# Extra SGLang parameters
request_id: str = Field(
default_factory=lambda: f"resp_{uuid.uuid4().hex}",
description="The request_id related to this request. If the caller does not set it, a random uuid will be generated.",
)
priority: int = Field(default=0, description="Request priority")
# SGLang-specific sampling parameters
frequency_penalty: float = 0.0
presence_penalty: float = 0.0
stop: Optional[Union[str, List[str]]] = None
top_k: int = -1
min_p: float = 0.0
repetition_penalty: float = 1.0
# Default sampling parameters
_DEFAULT_SAMPLING_PARAMS = {
"temperature": 0.7,
"top_p": 1.0,
"top_k": -1,
"min_p": 0.0,
"repetition_penalty": 1.0,
}
def to_sampling_params(
self, default_max_tokens: int, default_params: Optional[Dict] = None
) -> Dict[str, Any]:
"""Convert to sampling parameters for generation."""
if default_params is None:
default_params = {}
# Use max_output_tokens if provided (capped by the server default); otherwise fall back to the default
if self.max_output_tokens is not None:
max_tokens = min(self.max_output_tokens, default_max_tokens)
else:
max_tokens = default_max_tokens
# Reserve one token so the request cannot exceed the context length
max_tokens -= 1
# Get parameters with defaults
temperature = self.temperature
if temperature is None:
temperature = default_params.get(
"temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]
)
top_p = self.top_p
if top_p is None:
top_p = default_params.get("top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"])
params = {
"max_new_tokens": max_tokens,
"temperature": temperature,
"top_p": top_p,
"frequency_penalty": self.frequency_penalty,
"presence_penalty": self.presence_penalty,
"stop": self.stop,
"top_k": self.top_k,
"min_p": self.min_p,
"repetition_penalty": self.repetition_penalty,
}
# Apply any additional default parameters
for key, value in default_params.items():
if key not in params or params[key] is None:
params[key] = value
return params
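As a concrete illustration of the merging logic above (the values are assumptions chosen only for the example):

# Hypothetical use of ResponsesRequest.to_sampling_params.
req = ResponsesRequest(input="hello", max_output_tokens=128)
params = req.to_sampling_params(default_max_tokens=4096)
# params["max_new_tokens"] == 127   (one token reserved against the context limit)
# params["temperature"] == 0.7      (falls back to _DEFAULT_SAMPLING_PARAMS)
# params["top_p"] == 1.0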
class PromptTokenUsageInfo(BaseModel):
"""Prompt token usage details."""
cached_tokens: int = 0
class ResponsesResponse(BaseModel):
"""Response body for v1/responses endpoint."""
id: str = Field(default_factory=lambda: f"resp_{time.time()}")
object: Literal["response"] = "response"
created_at: int = Field(default_factory=lambda: int(time.time()))
model: str
output: List[
Union[ResponseOutputItem, ResponseReasoningItem, ResponseFunctionToolCall]
] = Field(default_factory=list)
status: Literal["queued", "in_progress", "completed", "failed", "cancelled"]
usage: Optional[UsageInfo] = None
parallel_tool_calls: bool = True
tool_choice: str = "auto"
tools: List[ResponseTool] = Field(default_factory=list)
@classmethod
def from_request(
cls,
request: ResponsesRequest,
sampling_params: Any,
model_name: str,
created_time: int,
output: List[
Union[ResponseOutputItem, ResponseReasoningItem, ResponseFunctionToolCall]
],
status: str,
usage: Optional[UsageInfo],
) -> "ResponsesResponse":
"""Create a response from a request."""
return cls(
id=request.request_id,
created_at=created_time,
model=model_name,
output=output,
status=status,
usage=usage,
parallel_tool_calls=(request.parallel_tool_calls if request.parallel_tool_calls is not None else True),
tool_choice=request.tool_choice,
tools=request.tools,
)
class RequestResponseMetadata(BaseModel):
"""Metadata for request/response tracking."""
request_id: str
final_usage_info: Optional[UsageInfo] = None
@dataclass
class MessageProcessingResult:
"""Result of processing chat messages and applying templates.
......@@ -645,3 +852,22 @@ class MessageProcessingResult:
modalities: List[str]
stop: List[str]
tool_call_constraint: Optional[Any] = None
class ResponseReasoningTextContent(BaseModel):
text: str
type: Literal["reasoning_text"] = "reasoning_text"
class ResponseReasoningItem(BaseModel):
id: str
content: list[ResponseReasoningTextContent] = Field(default_factory=list)
summary: list = Field(default_factory=list)
type: Literal["reasoning"] = "reasoning"
encrypted_content: Optional[str] = None
status: Optional[Literal["in_progress", "completed", "incomplete"]]
ResponseInputOutputItem: TypeAlias = Union[
ResponseInputItemParam, "ResponseReasoningItem", ResponseFunctionToolCall
]
......@@ -7,8 +7,18 @@ from typing import Any, AsyncGenerator, Dict, List, Optional, Union
from fastapi import Request
from fastapi.responses import ORJSONResponse, StreamingResponse
from openai_harmony import Message as OpenAIMessage
from sglang.srt.conversation import generate_chat_conv
from sglang.srt.entrypoints.harmony_utils import (
get_developer_message,
get_stop_tokens_for_assistant_actions,
get_streamable_parser_for_assistant,
get_system_message,
parse_chat_input,
parse_output_into_messages,
render_for_completion,
)
from sglang.srt.entrypoints.openai.protocol import (
ChatCompletionRequest,
ChatCompletionResponse,
......@@ -51,6 +61,26 @@ class OpenAIServingChat(OpenAIServingBase):
):
super().__init__(tokenizer_manager)
self.template_manager = template_manager
self.use_harmony = (
self.tokenizer_manager.model_config.hf_config.model_type == "gpt_oss"
)
if self.use_harmony:
from sglang.srt.function_call.harmony_tool_parser import (
HarmonyToolCallParser,
)
self.harmony_tool_parser = HarmonyToolCallParser()
# NOTE: While OpenAI's Chat Completions API supports browsing for some
# models, SGLang does not support it here yet. Please use the
# Responses API instead.
self.supports_browsing = False
self.browser_tool = None
# NOTE: The Chat Completions API does not support the code interpreter.
# Please use the Responses API instead.
self.supports_code_interpreter = False
self.python_tool = None
def _request_id_prefix(self) -> str:
return "chatcmpl-"
......@@ -77,41 +107,66 @@ class OpenAIServingChat(OpenAIServingBase):
is_multimodal = self.tokenizer_manager.model_config.is_multimodal
# Process messages and apply chat template
processed_messages = self._process_messages(request, is_multimodal)
# Build sampling parameters
sampling_params = self._build_sampling_params(
request, processed_messages.stop, processed_messages.tool_call_constraint
)
if not self.use_harmony:
processed_messages = self._process_messages(request, is_multimodal)
# Build sampling parameters
sampling_params = self._build_sampling_params(
request,
processed_messages.stop,
processed_messages.tool_call_constraint,
)
# Handle single vs multiple requests
if is_multimodal:
prompt_kwargs = {"text": processed_messages.prompt}
else:
if isinstance(processed_messages.prompt_ids, str):
prompt_kwargs = {"text": processed_messages.prompt_ids}
# Handle single vs multiple requests
if is_multimodal:
prompt_kwargs = {"text": processed_messages.prompt}
else:
prompt_kwargs = {"input_ids": processed_messages.prompt_ids}
adapted_request = GenerateReqInput(
**prompt_kwargs,
image_data=processed_messages.image_data,
video_data=processed_messages.video_data,
audio_data=processed_messages.audio_data,
sampling_params=sampling_params,
return_logprob=request.logprobs,
logprob_start_len=-1,
top_logprobs_num=request.top_logprobs or 0,
stream=request.stream,
return_text_in_logprobs=True,
modalities=processed_messages.modalities,
lora_path=request.lora_path,
bootstrap_host=request.bootstrap_host,
bootstrap_port=request.bootstrap_port,
bootstrap_room=request.bootstrap_room,
return_hidden_states=request.return_hidden_states,
rid=request.rid,
)
if isinstance(processed_messages.prompt_ids, str):
prompt_kwargs = {"text": processed_messages.prompt_ids}
else:
prompt_kwargs = {"input_ids": processed_messages.prompt_ids}
adapted_request = GenerateReqInput(
**prompt_kwargs,
image_data=processed_messages.image_data,
video_data=processed_messages.video_data,
audio_data=processed_messages.audio_data,
sampling_params=sampling_params,
return_logprob=request.logprobs,
logprob_start_len=-1,
top_logprobs_num=request.top_logprobs or 0,
stream=request.stream,
return_text_in_logprobs=True,
modalities=processed_messages.modalities,
lora_path=request.lora_path,
bootstrap_host=request.bootstrap_host,
bootstrap_port=request.bootstrap_port,
bootstrap_room=request.bootstrap_room,
return_hidden_states=request.return_hidden_states,
rid=request.rid,
)
else:
processed_messages, prompt_ids = self._make_request_with_harmony(request)
adapted_request = GenerateReqInput(
input_ids=prompt_ids,
sampling_params=self._build_sampling_params(
request,
request.stop,
tool_call_constraint=None,
),
stream=request.stream,
return_logprob=request.logprobs,
logprob_start_len=-1,
top_logprobs_num=request.top_logprobs or 0,
return_text_in_logprobs=True,
lora_path=request.lora_path,
bootstrap_host=request.bootstrap_host,
bootstrap_port=request.bootstrap_port,
bootstrap_room=request.bootstrap_room,
return_hidden_states=request.return_hidden_states,
rid=request.rid,
)
return adapted_request, request
......@@ -402,6 +457,12 @@ class OpenAIServingChat(OpenAIServingBase):
cached_tokens = {}
hidden_states = {}
# Harmony tracking
if self.use_harmony:
harmony_parsers = [
get_streamable_parser_for_assistant() for _ in range(request.n)
]
try:
async for content in self.tokenizer_manager.generate_request(
adapted_request, raw_request
......@@ -449,14 +510,57 @@ class OpenAIServingChat(OpenAIServingBase):
yield f"data: {chunk.model_dump_json()}\n\n"
# Process content delta
stream_buffer = stream_buffers.get(index, "")
delta = content["text"][len(stream_buffer) :]
stream_buffers[index] = stream_buffer + delta
if self.use_harmony:
harmony_parser = harmony_parsers[index]
new_token_ids = content["output_ids"]
for token_id in new_token_ids:
harmony_parser.process(token_id)
is_final = harmony_parser.current_channel == "final"
is_analysis = harmony_parser.current_channel == "analysis"
delta = harmony_parser.last_content_delta or ""
if is_analysis:
choice_data = ChatCompletionResponseStreamChoice(
index=index,
delta=DeltaMessage(reasoning_content=delta),
finish_reason=None,
)
chunk = ChatCompletionStreamResponse(
id=content["meta_info"]["id"],
created=int(time.time()),
choices=[choice_data],
model=request.model,
)
yield f"data: {chunk.model_dump_json()}\n\n"
continue
choice_data = ChatCompletionResponseStreamChoice(
index=index,
delta=DeltaMessage(content=delta if delta else None),
finish_reason=None,
matched_stop=None,
logprobs=choice_logprobs,
)
chunk = ChatCompletionStreamResponse(
id=content["meta_info"]["id"],
created=int(time.time()),
choices=[choice_data],
model=request.model,
)
yield f"data: {chunk.model_dump_json()}\n\n"
continue
else:
stream_buffer = stream_buffers.get(index, "")
delta = content["text"][len(stream_buffer) :]
stream_buffers[index] = stream_buffer + delta
# Handle reasoning content
if (
self.tokenizer_manager.server_args.reasoning_parser
and request.separate_reasoning
and not self.use_harmony
):
reasoning_text, delta = self._process_reasoning_stream(
index, delta, reasoning_parser_dict, content, request
......@@ -475,8 +579,27 @@ class OpenAIServingChat(OpenAIServingBase):
)
yield f"data: {chunk.model_dump_json()}\n\n"
if self.use_harmony and not is_final:
choice_data = ChatCompletionResponseStreamChoice(
index=index,
delta=DeltaMessage(reasoning_content=delta),
finish_reason=None,
)
chunk = ChatCompletionStreamResponse(
id=content["meta_info"]["id"],
created=int(time.time()),
choices=[choice_data],
model=request.model,
)
yield f"data: {chunk.model_dump_json()}\n\n"
# Handle tool calls
if request.tool_choice != "none" and request.tools:
# TODO: support tool call parsing for harmony
if (
request.tool_choice != "none"
and request.tools
and not self.use_harmony
):
async for chunk in self._process_tool_call_stream(
index,
delta,
......@@ -502,7 +625,7 @@ class OpenAIServingChat(OpenAIServingBase):
if delta:
choice_data = ChatCompletionResponseStreamChoice(
index=index,
delta=DeltaMessage(content=delta if delta else None),
delta=DeltaMessage(content=delta),
finish_reason=None,
matched_stop=None,
logprobs=choice_logprobs,
......@@ -640,6 +763,76 @@ class OpenAIServingChat(OpenAIServingBase):
finish_reason = ret_item["meta_info"]["finish_reason"]
text = ret_item["text"]
output_ids = ret_item["output_ids"]
if self.use_harmony:
parser = parse_output_into_messages(output_ids)
output_msgs = parser.messages
if len(output_msgs) == 0:
# The generation has stopped during reasoning.
is_tool_call = False
reasoning_content = parser.current_content
final_content = None
elif len(output_msgs) == 1:
# The generation has stopped during final message.
is_tool_call = False
reasoning_content = output_msgs[0].content[0].text
final_content = parser.current_content
else:
if len(output_msgs) != 2:
raise ValueError(
"Expected 2 output messages (reasoning and final), "
f"but got {len(output_msgs)}."
)
reasoning_msg, final_msg = output_msgs
reasoning_content = reasoning_msg.content[0].text
final_content = final_msg.content[0].text
is_tool_call = final_msg.recipient is not None
if is_tool_call:
# Extract tool call information from final message
tool_call = (
self.harmony_tool_parser.extract_tool_calls_from_message(
final_msg
)
)
tool_calls = [tool_call] if tool_call else []
message = ChatMessage(
role="assistant",
reasoning_content=reasoning_content,
content=None, # Tool calls don't have regular content
tool_calls=tool_calls,
)
else:
# Normal message
message = ChatMessage(
role="assistant",
reasoning_content=reasoning_content,
content=final_content,
)
if is_tool_call:
finish_reason_type = "tool_calls"
elif finish_reason:
finish_reason_type = finish_reason["type"]
else:
finish_reason_type = "stop"
choice_data = ChatCompletionResponseChoice(
index=idx,
message=message,
logprobs=choice_logprobs,
finish_reason=finish_reason_type,
matched_stop=(
finish_reason["matched"]
if finish_reason and "matched" in finish_reason
else None
),
)
choices.append(choice_data)
continue
# Handle reasoning content
reasoning_text = None
......@@ -978,3 +1171,33 @@ class OpenAIServingChat(OpenAIServingBase):
return f"data: {chunk.model_dump_json()}\n\n"
return None
def _make_request_with_harmony(
self,
request: ChatCompletionRequest,
):
messages: list[OpenAIMessage] = []
# Add system message.
# In the Chat Completions API, browsing is enabled by default if the
# model supports it.
assert not self.supports_browsing
assert not self.supports_code_interpreter
sys_msg = get_system_message(
reasoning_effort=request.reasoning_effort,
browser_description=None,
python_description=None,
)
messages.append(sys_msg)
# Add developer message.
dev_msg = get_developer_message()
messages.append(dev_msg)
# Add user message.
for chat_msg in request.messages:
messages.append(parse_chat_input(chat_msg))
# Render prompt token ids.
prompt_token_ids = render_for_completion(messages)
return messages, prompt_token_ids
This diff is collapsed.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import logging
from abc import ABC, abstractmethod
from contextlib import AbstractAsyncContextManager, asynccontextmanager
from typing import Any
logger = logging.getLogger(__name__)
try:
from mcp import ClientSession
from mcp.client.sse import sse_client
from mcp.types import ListToolsResult
except ImportError:
logger.warning("Ignoring mcp import error")
from openai_harmony import ToolDescription, ToolNamespaceConfig
async def list_server_and_tools(server_url: str):
async with sse_client(url=server_url) as streams, ClientSession(
*streams
) as session:
initialize_response = await session.initialize()
list_tools_response = await session.list_tools()
return initialize_response, list_tools_response
def trim_schema(schema: dict) -> dict:
# Turn JSON Schema from MCP generated into Harmony's variant.
if "title" in schema:
del schema["title"]
if "default" in schema and schema["default"] is None:
del schema["default"]
if "anyOf" in schema:
# Turn "anyOf": [{"type": "type-1"}, {"type": "type-2"}]
# into "type": ["type-1", "type-2"]
# Any "null" type is dropped along the way, since Harmony would just
# ignore it anyway.
types = [
type_dict["type"]
for type_dict in schema["anyOf"]
if type_dict["type"] != "null"
]
schema["type"] = types
del schema["anyOf"]
if "properties" in schema:
schema["properties"] = {
k: trim_schema(v) for k, v in schema["properties"].items()
}
return schema
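For instance, an MCP-generated property schema with a title, a null default, and an anyOf union would be rewritten like this (the schema contents are made up for the example):

# Illustrative input/output for trim_schema.
mcp_schema = {
    "title": "OpenPageArgs",
    "properties": {
        "url": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": None},
    },
}
print(trim_schema(mcp_schema))
# {'properties': {'url': {'type': ['string']}}}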
def post_process_tools_description(
list_tools_result: "ListToolsResult",
) -> "ListToolsResult":
# Adapt the MCP tool result for Harmony
for tool in list_tools_result.tools:
tool.inputSchema = trim_schema(tool.inputSchema)
# Some tool schemas don't need to be part of the prompt (e.g. simple
# text-in/text-out tools like Python).
list_tools_result.tools = [
tool
for tool in list_tools_result.tools
if getattr(tool.annotations, "include_in_prompt", True)
]
return list_tools_result
class ToolServer(ABC):
@abstractmethod
def has_tool(self, tool_name: str):
pass
@abstractmethod
def get_tool_description(self, tool_name: str):
pass
@abstractmethod
def get_tool_session(self, tool_name: str) -> AbstractAsyncContextManager[Any]: ...
class MCPToolServer(ToolServer):
def __init__(self):
self.harmony_tool_descriptions = {}
async def add_tool_server(self, server_url: str):
tool_urls = server_url.split(",")
self.harmony_tool_descriptions = {}
self.urls: dict[str, str] = {}
for url in tool_urls:
url = f"http://{url}/sse"
initialize_response, list_tools_response = await list_server_and_tools(url)
list_tools_response = post_process_tools_description(list_tools_response)
tool_from_mcp = ToolNamespaceConfig(
name=initialize_response.serverInfo.name,
description=initialize_response.instructions,
tools=[
ToolDescription.new(
name=tool.name,
description=tool.description,
parameters=tool.inputSchema,
)
for tool in list_tools_response.tools
],
)
self.harmony_tool_descriptions[tool_from_mcp.name] = tool_from_mcp
if tool_from_mcp.name not in self.urls:
self.urls[tool_from_mcp.name] = url
else:
logger.warning(
"Tool %s already exists. Ignoring duplicate tool server %s",
tool_from_mcp.name,
url,
)
def has_tool(self, tool_name: str):
return tool_name in self.harmony_tool_descriptions
def get_tool_description(self, tool_name: str):
return self.harmony_tool_descriptions.get(tool_name)
@asynccontextmanager
async def get_tool_session(self, tool_name: str):
url = self.urls.get(tool_name)
if url:
async with sse_client(url=url) as streams, ClientSession(
*streams
) as session:
await session.initialize()
yield session
else:
logger.warning("Tool %s not found", tool_name)
class DemoToolServer(ToolServer):
def __init__(self):
from sglang.srt.entrypoints.tool import (
HarmonyBrowserTool,
HarmonyPythonTool,
Tool,
)
self.tools: dict[str, Tool] = {}
browser_tool = HarmonyBrowserTool()
if browser_tool.enabled:
self.tools["browser"] = browser_tool
python_tool = HarmonyPythonTool()
if python_tool.enabled:
self.tools["python"] = python_tool
def has_tool(self, tool_name: str):
return tool_name in self.tools
def get_tool_description(self, tool_name: str):
if tool_name not in self.tools:
return None
if tool_name == "browser":
return ToolNamespaceConfig.browser()
elif tool_name == "python":
return ToolNamespaceConfig.python()
else:
raise ValueError(f"Unknown tool {tool_name}")
@asynccontextmanager
async def get_tool_session(self, tool_name: str):
yield self.tools[tool_name]
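A minimal sketch of how the demo tool server might be inspected and its tools obtained (assumes gpt_oss is installed and, for browsing, that EXA_API_KEY is set; otherwise the corresponding tool simply is not registered):

# Hypothetical driver for DemoToolServer.
import asyncio

async def describe_demo_tools() -> None:
    server = DemoToolServer()
    for name in ("browser", "python"):
        if server.has_tool(name):
            print(name, server.get_tool_description(name))
            async with server.get_tool_session(name) as tool:
                # `tool` is the HarmonyBrowserTool / HarmonyPythonTool instance;
                # its get_result(context) is what HarmonyContext.call_tool invokes.
                pass

asyncio.run(describe_demo_tools())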
# SPDX-License-Identifier: Apache-2.0
import logging
import os
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
# Avoid circular import.
from sglang.srt.entrypoints.context import ConversationContext
logger = logging.getLogger(__name__)
class Tool(ABC):
@abstractmethod
async def get_result(self, context: "ConversationContext") -> Any:
pass
class HarmonyBrowserTool(Tool):
def __init__(self):
self.enabled = True
exa_api_key = os.getenv("EXA_API_KEY")
if not exa_api_key:
self.enabled = False
logger.warning_once("EXA_API_KEY is not set, browsing is disabled")
return
try:
from gpt_oss.tools.simple_browser import SimpleBrowserTool
from gpt_oss.tools.simple_browser.backend import ExaBackend
except ImportError:
self.enabled = False
logger.warning_once("gpt_oss is not installed, browsing is disabled")
return
browser_backend = ExaBackend(source="web", api_key=exa_api_key)
self.browser_tool = SimpleBrowserTool(backend=browser_backend)
logger.info_once("Browser tool initialized")
async def get_result(self, context: "ConversationContext") -> Any:
from sglang.srt.entrypoints.context import HarmonyContext
assert isinstance(context, HarmonyContext)
last_msg = context.messages[-1]
tool_output_msgs = []
async for msg in self.browser_tool.process(last_msg):
tool_output_msgs.append(msg)
return tool_output_msgs
@property
def tool_config(self) -> Any:
return self.browser_tool.tool_config
class HarmonyPythonTool(Tool):
def __init__(self):
self.enabled = True
try:
from gpt_oss.tools.python_docker.docker_tool import PythonTool
except ImportError:
self.enabled = False
logger.warning("gpt_oss is not installed, code interpreter is disabled")
return
self.python_tool = PythonTool()
logger.info_once("Code interpreter tool initialized")
async def get_result(self, context: "ConversationContext") -> Any:
from sglang.srt.entrypoints.context import HarmonyContext
assert isinstance(context, HarmonyContext)
last_msg = context.messages[-1]
tool_output_msgs = []
async for msg in self.python_tool.process(last_msg):
tool_output_msgs.append(msg)
return tool_output_msgs
@property
def tool_config(self) -> Any:
return self.python_tool.tool_config
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Harmony tool call parser for processing tool calls in harmony models."""
import uuid
from typing import List, Optional, Tuple
from sglang.srt.entrypoints.openai.protocol import (
ChatMessage,
FunctionResponse,
ToolCall,
)
class HarmonyToolCallParser:
"""Parser for extracting tool calls from harmony model outputs."""
def extract_tool_calls_from_message(self, msg) -> Optional[ToolCall]:
"""
Extract tool call from a single message if it's a tool call.
Args:
msg: The harmony message
Returns:
ToolCall if the message is a tool call, None otherwise
"""
if (
msg.channel == "commentary"
and msg.recipient
and msg.recipient.startswith("functions.")
):
function_name = msg.recipient.split(".")[-1]
arguments = msg.content[0].text if msg.content else "{}"
return ToolCall(
id=f"call_{uuid.uuid4().hex[:24]}",
function=FunctionResponse(
name=function_name,
arguments=arguments,
),
)
return None
def process_streaming_chunk(
self,
harmony_parser,
index: int,
tool_call_trackers: dict,
stream_buffers: dict,
) -> Tuple[Optional[dict], bool, Optional[str]]:
"""
Process a streaming chunk for tool calls.
Args:
harmony_parser: The harmony parser instance
index: The choice index
tool_call_trackers: Dict tracking tool calls per choice
stream_buffers: Dict for buffering content
Returns:
Tuple of (tool_call_data, is_tool_call, delta)
"""
# Check if we're in a tool call
is_tool_call = (
harmony_parser.current_channel == "commentary"
and harmony_parser.current_recipient
and harmony_parser.current_recipient.startswith("functions.")
)
delta = harmony_parser.last_content_delta or ""
tool_call_data = None
if is_tool_call:
# Handle tool call streaming
function_name = harmony_parser.current_recipient.split(".")[-1]
# Track tool call indices per choice
if index not in tool_call_trackers:
tool_call_trackers[index] = {"count": 0, "current_function": None}
# Check if we just started a new tool call
tool_call_tracker = tool_call_trackers[index]
if tool_call_tracker["current_function"] != function_name:
# New tool call started
tool_call_tracker["current_function"] = function_name
tool_call_index = tool_call_tracker["count"]
tool_call_tracker["count"] += 1
# Store the tool call index for this function
tool_call_key = f"{index}_{function_name}"
stream_buffers[tool_call_key] = {
"index": tool_call_index,
"content": "",
}
tool_call_data = {
"id": f"call_{uuid.uuid4().hex[:24]}",
"index": tool_call_index,
"function_name": function_name,
"arguments": delta,
"is_first_chunk": True,
}
else:
# Subsequent chunks for the same tool call
tool_call_key = f"{index}_{function_name}"
tool_call_index = stream_buffers[tool_call_key]["index"]
tool_call_data = {
"id": None,
"index": tool_call_index,
"function_name": None,
"arguments": delta,
"is_first_chunk": False,
}
stream_buffers[tool_call_key]["content"] += delta
return tool_call_data, is_tool_call, delta
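As an illustration of the non-streaming path, a hand-built harmony commentary message maps onto a ToolCall like this (the function name and arguments are made up for the example):

# Illustrative use of extract_tool_calls_from_message.
from openai_harmony import Message, Role

msg = (
    Message.from_role_and_content(Role.ASSISTANT, '{"city": "Paris"}')
    .with_channel("commentary")
    .with_recipient("functions.get_weather")  # hypothetical function
)
tool_call = HarmonyToolCallParser().extract_tool_calls_from_message(msg)
# tool_call.function.name == "get_weather"
# tool_call.function.arguments == '{"city": "Paris"}'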
......@@ -216,7 +216,7 @@ class DetokenizerManager:
rids=recv_obj.rids,
finished_reasons=recv_obj.finished_reasons,
output_strs=output_strs,
output_ids=None,
output_ids=recv_obj.decode_ids,
prompt_tokens=recv_obj.prompt_tokens,
completion_tokens=recv_obj.completion_tokens,
cached_tokens=recv_obj.cached_tokens,
......
......@@ -126,6 +126,9 @@ class GenerateReqInput:
# For data parallel rank routing
data_parallel_rank: Optional[int] = None
# For background responses (OpenAI responses API)
background: bool = False
def contains_mm_input(self) -> bool:
return (
has_valid_data(self.image_data)
......@@ -560,6 +563,9 @@ class EmbeddingReqInput:
# For cross-encoder requests
is_cross_encoder_request: bool = False
# For background responses (OpenAI responses API)
background: bool = False
def normalize_batch_and_arguments(self):
# at least one of text, input_ids, or image should be provided
if self.text is None and self.input_ids is None and self.image_data is None:
......
......@@ -571,8 +571,7 @@ class SchedulerOutputProcessorMixin:
req.send_decode_id_offset = len(decode_ids)
read_offsets.append(read_offset)
if self.skip_tokenizer_init:
output_ids.append(req.output_ids[send_token_offset:])
output_ids.append(req.output_ids[send_token_offset:])
req.send_token_offset = len(req.output_ids)
skip_special_tokens.append(req.sampling_params.skip_special_tokens)
spaces_between_special_tokens.append(
......
......@@ -750,7 +750,11 @@ class TokenizerManager:
try:
await asyncio.wait_for(state.event.wait(), timeout=4)
except asyncio.TimeoutError:
if request is not None and await request.is_disconnected():
if (
request is not None
and not obj.background
and await request.is_disconnected()
):
# Abort the request for disconnected requests (non-streaming, waiting queue)
self.abort_request(obj.rid)
# Use exception to kill the whole call stack and asyncio task
......@@ -805,7 +809,11 @@ class TokenizerManager:
if obj.stream:
yield out
else:
if request is not None and await request.is_disconnected():
if (
request is not None
and not obj.background
and await request.is_disconnected()
):
# Abort the request for disconnected requests (non-streaming, running)
self.abort_request(obj.rid)
# Use exception to kill the whole call stack and asyncio task
......@@ -1548,8 +1556,17 @@ class TokenizerManager:
if isinstance(recv_obj, BatchStrOut):
state.text += recv_obj.output_strs[i]
if state.obj.stream:
state.output_ids.extend(recv_obj.output_ids[i])
output_token_ids = state.output_ids[state.last_output_offset :]
state.last_output_offset = len(state.output_ids)
else:
state.output_ids.extend(recv_obj.output_ids[i])
output_token_ids = state.output_ids.copy()
out_dict = {
"text": state.text,
"output_ids": output_token_ids,
"meta_info": meta_info,
}
elif isinstance(recv_obj, BatchTokenIDOut):
......
......@@ -274,6 +274,9 @@ class ServerArgs:
enable_pdmux: bool = False
sm_group_num: int = 3
# For tool server
tool_server: Optional[str] = None
# Deprecated arguments
enable_ep_moe: bool = False
enable_deepep_moe: bool = False
......@@ -1916,6 +1919,14 @@ class ServerArgs:
help="Disable mmap while loading weight using safetensors.",
)
# For tool server
parser.add_argument(
"--tool-server",
type=str,
default=None,
help="Either 'demo' or a comma-separated list of tool server urls to use for the model. If not specified, no tool server will be used.",
)
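For example, launching with `--tool-server demo` enables the built-in browser and Python demo tools (when their dependencies are available), while `--tool-server host1:port1,host2:port2` registers MCP tool servers, each of which is reached at `http://<host:port>/sse`.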
# Deprecated arguments
parser.add_argument(
"--enable-ep-moe",
......
......@@ -41,6 +41,7 @@ import tempfile
import threading
import time
import traceback
import uuid
import warnings
from collections import OrderedDict, defaultdict
from contextlib import contextmanager
......@@ -233,6 +234,10 @@ def is_flashinfer_available():
return importlib.util.find_spec("flashinfer") is not None and is_cuda()
def random_uuid() -> str:
return str(uuid.uuid4().hex)
_ENABLE_TORCH_INFERENCE_MODE = get_bool_env_var(
"SGLANG_ENABLE_TORCH_INFERENCE_MODE", "false"
)
......