"benchmarks/vscode:/vscode.git/clone" did not exist on "7ee228bf68fe8624da596860949ceb5ebc1b1dfe"
Unverified commit 72676cd6 authored by Chang Su, committed by GitHub

feat(oai refactor): Replace `openai_api` with `entrypoints/openai` (#7351)


Co-authored-by: Jin Pan <jpan236@wisc.edu>
parent 02bf31ef
......@@ -2,6 +2,7 @@ import json
import logging
from typing import List
from sglang.srt.entrypoints.openai.protocol import Tool
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
from sglang.srt.function_call.core_types import (
StreamingParseResult,
......@@ -9,7 +10,6 @@ from sglang.srt.function_call.core_types import (
_GetInfoFunc,
)
from sglang.srt.function_call.ebnf_composer import EBNFComposer
from sglang.srt.openai_api.protocol import Tool
logger = logging.getLogger(__name__)
......
......@@ -3,6 +3,7 @@ import logging
import re
from typing import List
from sglang.srt.entrypoints.openai.protocol import Tool
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
from sglang.srt.function_call.core_types import (
StreamingParseResult,
......@@ -10,7 +11,6 @@ from sglang.srt.function_call.core_types import (
_GetInfoFunc,
)
from sglang.srt.function_call.ebnf_composer import EBNFComposer
from sglang.srt.openai_api.protocol import Tool
logger = logging.getLogger(__name__)
......
......@@ -4,6 +4,7 @@ import logging
import re
from typing import List, Optional
from sglang.srt.entrypoints.openai.protocol import Tool
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
from sglang.srt.function_call.core_types import (
StreamingParseResult,
......@@ -12,7 +13,6 @@ from sglang.srt.function_call.core_types import (
_GetInfoFunc,
)
from sglang.srt.function_call.ebnf_composer import EBNFComposer
from sglang.srt.openai_api.protocol import Tool
logger = logging.getLogger(__name__)
......
......@@ -3,6 +3,7 @@ import logging
import re
from typing import List
from sglang.srt.entrypoints.openai.protocol import Tool
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
from sglang.srt.function_call.core_types import (
StreamingParseResult,
......@@ -10,7 +11,6 @@ from sglang.srt.function_call.core_types import (
_GetInfoFunc,
)
from sglang.srt.function_call.ebnf_composer import EBNFComposer
from sglang.srt.openai_api.protocol import Tool
logger = logging.getLogger(__name__)
......
"""
Utility functions for OpenAI API adapter.
"""Template utilities for Jinja template processing.
This module provides utilities for analyzing and processing Jinja chat templates,
including content format detection and message processing.
"""
import logging
from typing import Dict, List
import jinja2.nodes
import jinja2
import transformers.utils.chat_template_utils as hf_chat_utils
logger = logging.getLogger(__name__)
......@@ -75,7 +76,7 @@ def _try_extract_ast(chat_template: str):
return None
def detect_template_content_format(chat_template: str) -> str:
def detect_jinja_template_content_format(chat_template: str) -> str:
"""
Detect whether a chat template expects 'string' or 'openai' content format.
......
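As a quick illustration, here is a hedged sketch (not part of the diff) of calling the renamed helper. The import path is the one used later in this commit; the exact detection heuristic lives in `_try_extract_ast` and surrounding code not shown here, so the expected return values below are assumptions based on the docstring ("string" vs. "openai").

from sglang.srt.jinja_template_utils import detect_jinja_template_content_format

# Renders message.content as one opaque value -> plain-string content expected.
string_template = "{% for m in messages %}{{ m.content }}{% endfor %}"

# Iterates over message.content -> OpenAI-style list-of-parts content expected.
openai_template = (
    "{% for m in messages %}{% for part in m.content %}"
    "{{ part.text }}{% endfor %}{% endfor %}"
)

print(detect_jinja_template_content_format(string_template))  # expected: "string"
print(detect_jinja_template_content_format(openai_template))  # expected: "openai"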
......@@ -864,12 +864,6 @@ class SetInternalStateReq:
server_args: Dict[str, Any]
@dataclass
class V1RerankReqInput:
query: str
documents: List[str]
@dataclass
class SetInternalStateReqOutput:
updated: bool
......
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Centralized template management for chat templates and completion templates.
This module provides a unified interface for managing both chat conversation templates
and code completion templates, eliminating global state and improving modularity.
"""
import json
import logging
import os
from typing import Optional
from sglang.srt.code_completion_parser import (
CompletionTemplate,
FimPosition,
completion_template_exists,
register_completion_template,
)
from sglang.srt.conversation import (
Conversation,
SeparatorStyle,
chat_template_exists,
get_conv_template_by_model_path,
register_conv_template,
)
from sglang.srt.jinja_template_utils import detect_jinja_template_content_format
logger = logging.getLogger(__name__)
class TemplateManager:
"""
Centralized manager for chat and completion templates.
This class encapsulates all template-related state and operations,
eliminating the need for global variables and providing a clean
interface for template management.
"""
def __init__(self):
self._chat_template_name: Optional[str] = None
self._completion_template_name: Optional[str] = None
self._jinja_template_content_format: Optional[str] = None
@property
def chat_template_name(self) -> Optional[str]:
"""Get the current chat template name."""
return self._chat_template_name
@property
def completion_template_name(self) -> Optional[str]:
"""Get the current completion template name."""
return self._completion_template_name
@property
def jinja_template_content_format(self) -> Optional[str]:
"""Get the detected template content format ('string' or 'openai' or None)."""
return self._jinja_template_content_format
def load_chat_template(
self, tokenizer_manager, chat_template_arg: str, model_path: str
) -> None:
"""
Load a chat template from various sources.
Args:
tokenizer_manager: The tokenizer manager instance
chat_template_arg: Template name or file path
model_path: Path to the model
"""
logger.info(f"Loading chat template: {chat_template_arg}")
if not chat_template_exists(chat_template_arg):
if not os.path.exists(chat_template_arg):
raise RuntimeError(
f"Chat template {chat_template_arg} is not a built-in template name "
"or a valid chat template file path."
)
if chat_template_arg.endswith(".jinja"):
self._load_jinja_template(tokenizer_manager, chat_template_arg)
else:
self._load_json_chat_template(chat_template_arg)
else:
self._chat_template_name = chat_template_arg
def guess_chat_template_from_model_path(self, model_path: str) -> None:
"""
Infer chat template name from model path.
Args:
model_path: Path to the model
"""
template_name = get_conv_template_by_model_path(model_path)
if template_name is not None:
logger.info(f"Inferred chat template from model path: {template_name}")
self._chat_template_name = template_name
def load_completion_template(self, completion_template_arg: str) -> None:
"""
Load completion template for code completion.
Args:
completion_template_arg: Template name or file path
"""
logger.info(f"Loading completion template: {completion_template_arg}")
if not completion_template_exists(completion_template_arg):
if not os.path.exists(completion_template_arg):
raise RuntimeError(
f"Completion template {completion_template_arg} is not a built-in template name "
"or a valid completion template file path."
)
self._load_json_completion_template(completion_template_arg)
else:
self._completion_template_name = completion_template_arg
def initialize_templates(
self,
tokenizer_manager,
model_path: str,
chat_template: Optional[str] = None,
completion_template: Optional[str] = None,
) -> None:
"""
Initialize all templates based on provided configuration.
Args:
tokenizer_manager: The tokenizer manager instance
model_path: Path to the model
chat_template: Optional chat template name/path
completion_template: Optional completion template name/path
"""
# Load chat template
if chat_template:
self.load_chat_template(tokenizer_manager, chat_template, model_path)
else:
self.guess_chat_template_from_model_path(model_path)
# Load completion template
if completion_template:
self.load_completion_template(completion_template)
def _load_jinja_template(self, tokenizer_manager, template_path: str) -> None:
"""Load a Jinja template file."""
with open(template_path, "r") as f:
chat_template = "".join(f.readlines()).strip("\n")
tokenizer_manager.tokenizer.chat_template = chat_template.replace("\\n", "\n")
self._chat_template_name = None
# Detect content format from the loaded template
self._jinja_template_content_format = detect_jinja_template_content_format(
chat_template
)
logger.info(
f"Detected chat template content format: {self._jinja_template_content_format}"
)
def _load_json_chat_template(self, template_path: str) -> None:
"""Load a JSON chat template file."""
assert template_path.endswith(
".json"
), "unrecognized format of chat template file"
with open(template_path, "r") as filep:
template = json.load(filep)
try:
sep_style = SeparatorStyle[template["sep_style"]]
except KeyError:
raise ValueError(
f"Unknown separator style: {template['sep_style']}"
) from None
register_conv_template(
Conversation(
name=template["name"],
system_template=template["system"] + "\n{system_message}",
system_message=template.get("system_message", ""),
roles=(template["user"], template["assistant"]),
sep_style=sep_style,
sep=template.get("sep", "\n"),
stop_str=template["stop_str"],
),
override=True,
)
self._chat_template_name = template["name"]
def _load_json_completion_template(self, template_path: str) -> None:
"""Load a JSON completion template file."""
assert template_path.endswith(
".json"
), "unrecognized format of completion template file"
with open(template_path, "r") as filep:
template = json.load(filep)
try:
fim_position = FimPosition[template["fim_position"]]
except KeyError:
raise ValueError(
f"Unknown fim position: {template['fim_position']}"
) from None
register_completion_template(
CompletionTemplate(
name=template["name"],
fim_begin_token=template["fim_begin_token"],
fim_middle_token=template["fim_middle_token"],
fim_end_token=template["fim_end_token"],
fim_position=fim_position,
),
override=True,
)
self._completion_template_name = template["name"]
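A hedged usage sketch of the new manager, based only on the methods defined above; `tokenizer_manager`, the model path, and the template file name are placeholders, not values taken from this commit.

template_manager = TemplateManager()
template_manager.initialize_templates(
    tokenizer_manager,                           # an existing TokenizerManager
    model_path="org/placeholder-model",          # placeholder model path
    chat_template="my_chat_template.jinja",      # placeholder .jinja file or built-in name
    completion_template=None,
)

# Callers read the resulting state through the read-only properties.
print(template_manager.chat_template_name)             # None when a raw Jinja file was loaded
print(template_manager.jinja_template_content_format)  # "string", "openai", or None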
......@@ -1058,12 +1058,7 @@ class TokenizerManager:
"lora_path",
]
)
out_skip_names = set(
[
"text",
"output_ids",
]
)
out_skip_names = set(["text", "output_ids", "embedding"])
elif self.log_requests_level == 1:
max_length = 2048
elif self.log_requests_level == 2:
......
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Pydantic models for OpenAI API protocol"""
import time
from typing import Dict, List, Optional, Union
from pydantic import BaseModel, Field, model_serializer, root_validator
from typing_extensions import Literal
class ModelCard(BaseModel):
"""Model cards."""
id: str
object: str = "model"
created: int = Field(default_factory=lambda: int(time.time()))
owned_by: str = "sglang"
root: Optional[str] = None
max_model_len: Optional[int] = None
class ModelList(BaseModel):
"""Model list consists of model cards."""
object: str = "list"
data: List[ModelCard] = Field(default_factory=list)
class ErrorResponse(BaseModel):
object: str = "error"
message: str
type: str
param: Optional[str] = None
code: int
class LogProbs(BaseModel):
text_offset: List[int] = Field(default_factory=list)
token_logprobs: List[Optional[float]] = Field(default_factory=list)
tokens: List[str] = Field(default_factory=list)
top_logprobs: List[Optional[Dict[str, float]]] = Field(default_factory=list)
class TopLogprob(BaseModel):
token: str
bytes: List[int]
logprob: float
class ChatCompletionTokenLogprob(BaseModel):
token: str
bytes: List[int]
logprob: float
top_logprobs: List[TopLogprob]
class ChoiceLogprobs(BaseModel):
# build for v1/chat/completions response
content: List[ChatCompletionTokenLogprob]
class UsageInfo(BaseModel):
prompt_tokens: int = 0
total_tokens: int = 0
completion_tokens: Optional[int] = 0
# only used to return cached tokens when --enable-cache-report is set
prompt_tokens_details: Optional[Dict[str, int]] = None
class StreamOptions(BaseModel):
include_usage: Optional[bool] = False
class JsonSchemaResponseFormat(BaseModel):
name: str
description: Optional[str] = None
# use alias to workaround pydantic conflict
schema_: Optional[Dict[str, object]] = Field(alias="schema", default=None)
strict: Optional[bool] = False
class FileRequest(BaseModel):
# https://platform.openai.com/docs/api-reference/files/create
file: bytes # The File object (not file name) to be uploaded
purpose: str = (
"batch" # The intended purpose of the uploaded file, default is "batch"
)
class FileResponse(BaseModel):
id: str
object: str = "file"
bytes: int
created_at: int
filename: str
purpose: str
class FileDeleteResponse(BaseModel):
id: str
object: str = "file"
deleted: bool
class BatchRequest(BaseModel):
input_file_id: (
str # The ID of an uploaded file that contains requests for the new batch
)
endpoint: str # The endpoint to be used for all requests in the batch
completion_window: str # The time frame within which the batch should be processed
metadata: Optional[dict] = None # Optional custom metadata for the batch
class BatchResponse(BaseModel):
id: str
object: str = "batch"
endpoint: str
errors: Optional[dict] = None
input_file_id: str
completion_window: str
status: str = "validating"
output_file_id: Optional[str] = None
error_file_id: Optional[str] = None
created_at: int
in_progress_at: Optional[int] = None
expires_at: Optional[int] = None
finalizing_at: Optional[int] = None
completed_at: Optional[int] = None
failed_at: Optional[int] = None
expired_at: Optional[int] = None
cancelling_at: Optional[int] = None
cancelled_at: Optional[int] = None
request_counts: Optional[dict] = None
metadata: Optional[dict] = None
class CompletionRequest(BaseModel):
# Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/completions/create
model: str
prompt: Union[List[int], List[List[int]], str, List[str]]
best_of: Optional[int] = None
echo: bool = False
frequency_penalty: float = 0.0
logit_bias: Optional[Dict[str, float]] = None
logprobs: Optional[int] = None
max_tokens: int = 16
n: int = 1
presence_penalty: float = 0.0
seed: Optional[int] = None
stop: Optional[Union[str, List[str]]] = None
stream: bool = False
stream_options: Optional[StreamOptions] = None
suffix: Optional[str] = None
temperature: float = 1.0
top_p: float = 1.0
user: Optional[str] = None
# Extra parameters for SRT backend only and will be ignored by OpenAI models.
top_k: int = -1
min_p: float = 0.0
min_tokens: int = 0
json_schema: Optional[str] = None
regex: Optional[str] = None
ebnf: Optional[str] = None
repetition_penalty: float = 1.0
stop_token_ids: Optional[List[int]] = None
no_stop_trim: bool = False
ignore_eos: bool = False
skip_special_tokens: bool = True
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
session_params: Optional[Dict] = None
return_hidden_states: Optional[bool] = False
# For PD disaggregation
bootstrap_host: Optional[str] = None
bootstrap_port: Optional[int] = None
bootstrap_room: Optional[int] = None
class CompletionResponseChoice(BaseModel):
index: int
text: str
logprobs: Optional[LogProbs] = None
finish_reason: Literal["stop", "length", "content_filter", "abort"]
matched_stop: Union[None, int, str] = None
hidden_states: Optional[object] = None
@model_serializer
def _serialize(self):
return exclude_if_none(self, ["hidden_states"])
class CompletionResponse(BaseModel):
id: str
object: str = "text_completion"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[CompletionResponseChoice]
usage: UsageInfo
class CompletionResponseStreamChoice(BaseModel):
index: int
text: str
logprobs: Optional[LogProbs] = None
finish_reason: Optional[Literal["stop", "length", "content_filter"]] = None
matched_stop: Union[None, int, str] = None
hidden_states: Optional[object] = None
@model_serializer
def _serialize(self):
return exclude_if_none(self, ["hidden_states"])
class CompletionStreamResponse(BaseModel):
id: str
object: str = "text_completion"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[CompletionResponseStreamChoice]
usage: Optional[UsageInfo] = None
class ChatCompletionMessageContentTextPart(BaseModel):
type: Literal["text"]
text: str
class ChatCompletionMessageContentImageURL(BaseModel):
url: str
detail: Optional[Literal["auto", "low", "high"]] = "auto"
class ChatCompletionMessageContentAudioURL(BaseModel):
url: str
class ChatCompletionMessageContentImagePart(BaseModel):
type: Literal["image_url"]
image_url: ChatCompletionMessageContentImageURL
modalities: Optional[Literal["image", "multi-images", "video"]] = "image"
class ChatCompletionMessageContentAudioPart(BaseModel):
type: Literal["audio_url"]
audio_url: ChatCompletionMessageContentAudioURL
ChatCompletionMessageContentPart = Union[
ChatCompletionMessageContentTextPart,
ChatCompletionMessageContentImagePart,
ChatCompletionMessageContentAudioPart,
]
class FunctionResponse(BaseModel):
"""Function response."""
name: Optional[str] = None
arguments: Optional[str] = None
class ToolCall(BaseModel):
"""Tool call response."""
id: Optional[str] = None
index: Optional[int] = None
type: Literal["function"] = "function"
function: FunctionResponse
class ChatCompletionMessageGenericParam(BaseModel):
role: Literal["system", "assistant", "tool"]
content: Union[str, List[ChatCompletionMessageContentTextPart], None]
tool_call_id: Optional[str] = None
name: Optional[str] = None
reasoning_content: Optional[str] = None
tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])
class ChatCompletionMessageUserParam(BaseModel):
role: Literal["user"]
content: Union[str, List[ChatCompletionMessageContentPart]]
ChatCompletionMessageParam = Union[
ChatCompletionMessageGenericParam, ChatCompletionMessageUserParam
]
class ResponseFormat(BaseModel):
type: Literal["text", "json_object", "json_schema"]
json_schema: Optional[JsonSchemaResponseFormat] = None
class StructuresResponseFormat(BaseModel):
begin: str
schema_: Optional[Dict[str, object]] = Field(alias="schema", default=None)
end: str
class StructuralTagResponseFormat(BaseModel):
type: Literal["structural_tag"]
structures: List[StructuresResponseFormat]
triggers: List[str]
class Function(BaseModel):
"""Function descriptions."""
description: Optional[str] = Field(default=None, examples=[None])
name: Optional[str] = None
parameters: Optional[object] = None
strict: bool = False
class Tool(BaseModel):
"""Function wrapper."""
type: str = Field(default="function", examples=["function"])
function: Function
class ToolChoiceFuncName(BaseModel):
"""The name of tool choice function."""
name: Optional[str] = None
class ToolChoice(BaseModel):
"""The tool choice definition."""
function: ToolChoiceFuncName
type: Literal["function"] = Field(default="function", examples=["function"])
class ChatCompletionRequest(BaseModel):
# Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/chat/create
messages: List[ChatCompletionMessageParam]
model: str
frequency_penalty: float = 0.0
logit_bias: Optional[Dict[str, float]] = None
logprobs: bool = False
top_logprobs: Optional[int] = None
max_tokens: Optional[int] = Field(
default=None,
deprecated="max_tokens is deprecated in favor of the max_completion_tokens field",
description="The maximum number of tokens that can be generated in the chat completion. ",
)
max_completion_tokens: Optional[int] = Field(
default=None,
description="The maximum number of completion tokens for a chat completion request, "
"including visible output tokens and reasoning tokens. Input tokens are not included. ",
)
n: int = 1
presence_penalty: float = 0.0
response_format: Optional[Union[ResponseFormat, StructuralTagResponseFormat]] = None
seed: Optional[int] = None
stop: Optional[Union[str, List[str]]] = None
stream: bool = False
stream_options: Optional[StreamOptions] = None
temperature: float = 0.7
top_p: float = 1.0
user: Optional[str] = None
tools: Optional[List[Tool]] = Field(default=None, examples=[None])
tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] = Field(
default="auto", examples=["none"]
) # noqa
@root_validator(pre=True)
def set_tool_choice_default(cls, values):
if values.get("tool_choice") is None:
if values.get("tools") is None:
values["tool_choice"] = "none"
else:
values["tool_choice"] = "auto"
return values
# Extra parameters for SRT backend only and will be ignored by OpenAI models.
top_k: int = -1
min_p: float = 0.0
min_tokens: int = 0
regex: Optional[str] = None
ebnf: Optional[str] = None
repetition_penalty: float = 1.0
stop_token_ids: Optional[List[int]] = None
no_stop_trim: bool = False
ignore_eos: bool = False
continue_final_message: bool = False
skip_special_tokens: bool = True
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
session_params: Optional[Dict] = None
separate_reasoning: bool = True
stream_reasoning: bool = True
chat_template_kwargs: Optional[Dict] = None
# The request id.
rid: Optional[str] = None
# For PD disaggregation
bootstrap_host: Optional[str] = None
bootstrap_port: Optional[int] = None
bootstrap_room: Optional[int] = None
# Hidden States
return_hidden_states: Optional[bool] = False
class ChatMessage(BaseModel):
role: Optional[str] = None
content: Optional[str] = None
reasoning_content: Optional[str] = None
tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])
class ChatCompletionResponseChoice(BaseModel):
index: int
message: ChatMessage
logprobs: Optional[Union[LogProbs, ChoiceLogprobs]] = None
finish_reason: Literal[
"stop", "length", "tool_calls", "content_filter", "function_call", "abort"
]
matched_stop: Union[None, int, str] = None
hidden_states: Optional[object] = None
@model_serializer
def _serialize(self):
return exclude_if_none(self, ["hidden_states"])
class ChatCompletionResponse(BaseModel):
id: str
object: str = "chat.completion"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[ChatCompletionResponseChoice]
usage: UsageInfo
class DeltaMessage(BaseModel):
role: Optional[str] = None
content: Optional[str] = None
reasoning_content: Optional[str] = None
tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])
hidden_states: Optional[object] = None
@model_serializer
def _serialize(self):
return exclude_if_none(self, ["hidden_states"])
class ChatCompletionResponseStreamChoice(BaseModel):
index: int
delta: DeltaMessage
logprobs: Optional[Union[LogProbs, ChoiceLogprobs]] = None
finish_reason: Optional[
Literal["stop", "length", "tool_calls", "content_filter", "function_call"]
] = None
matched_stop: Union[None, int, str] = None
class ChatCompletionStreamResponse(BaseModel):
id: str
object: str = "chat.completion.chunk"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[ChatCompletionResponseStreamChoice]
usage: Optional[UsageInfo] = None
class MultimodalEmbeddingInput(BaseModel):
text: Optional[str] = None
image: Optional[str] = None
class EmbeddingRequest(BaseModel):
# Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/embeddings/create
input: Union[
List[int], List[List[int]], str, List[str], List[MultimodalEmbeddingInput]
]
model: str
encoding_format: str = "float"
dimensions: int = None
user: Optional[str] = None
# The request id.
rid: Optional[str] = None
class EmbeddingObject(BaseModel):
embedding: List[float]
index: int
object: str = "embedding"
class EmbeddingResponse(BaseModel):
data: List[EmbeddingObject]
model: str
object: str = "list"
usage: Optional[UsageInfo] = None
class ScoringRequest(BaseModel):
query: Optional[Union[str, List[int]]] = (
None # Query text or pre-tokenized token IDs
)
items: Optional[Union[str, List[str], List[List[int]]]] = (
None # Item text(s) or pre-tokenized token IDs
)
label_token_ids: Optional[List[int]] = (
None # Token IDs to compute probabilities for
)
apply_softmax: bool = False
item_first: bool = False
model: str
class ScoringResponse(BaseModel):
scores: List[
List[float]
] # List of lists of probabilities, each in the order of label_token_ids
model: str
usage: Optional[UsageInfo] = None
object: str = "scoring"
class RerankResponse(BaseModel):
score: float
document: str
index: int
meta_info: Optional[dict] = None
def exclude_if_none(obj, field_names: List[str]):
omit_if_none_fields = {k for k, v in obj.model_fields.items() if k in field_names}
return {k: v for k, v in obj if k not in omit_if_none_fields or v is not None}
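A small, hedged sketch (not part of the commit) exercising two behaviours defined above: the `set_tool_choice_default` pre-validator fills in `tool_choice` when it is omitted, and the `exclude_if_none` serializer drops an unset `hidden_states` field from the dump. The model name is a placeholder.

from sglang.srt.entrypoints.openai.protocol import (
    ChatCompletionRequest,
    CompletionResponseChoice,
)

req = ChatCompletionRequest(
    model="placeholder-model",
    messages=[{"role": "user", "content": "Hello"}],
)
# No tools were supplied, so the pre-validator defaults tool_choice to "none".
assert req.tool_choice == "none"

choice = CompletionResponseChoice(index=0, text="hi", finish_reason="stop")
# hidden_states is None, so the custom serializer omits it from the output.
assert "hidden_states" not in choice.model_dump()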
from typing import Dict, Tuple
from typing import Dict, Optional, Tuple, Type
class StreamingParseResult:
......@@ -32,17 +32,26 @@ class BaseReasoningFormatDetector:
One-time parsing: Detects and parses reasoning sections in the provided text.
Returns both reasoning content and normal text separately.
"""
text = text.replace(self.think_start_token, "").strip()
if self.think_end_token not in text:
in_reasoning = self._in_reasoning or text.startswith(self.think_start_token)
if not in_reasoning:
return StreamingParseResult(normal_text=text)
# The text is considered to be in a reasoning block.
processed_text = text.replace(self.think_start_token, "").strip()
if self.think_end_token not in processed_text:
# Assume reasoning was truncated before `</think>` token
return StreamingParseResult(reasoning_text=text)
return StreamingParseResult(reasoning_text=processed_text)
# Extract reasoning content
splits = text.split(self.think_end_token, maxsplit=1)
splits = processed_text.split(self.think_end_token, maxsplit=1)
reasoning_text = splits[0]
text = splits[1].strip()
normal_text = splits[1].strip()
return StreamingParseResult(normal_text=text, reasoning_text=reasoning_text)
return StreamingParseResult(
normal_text=normal_text, reasoning_text=reasoning_text
)
def parse_streaming_increment(self, new_text: str) -> StreamingParseResult:
"""
......@@ -61,6 +70,7 @@ class BaseReasoningFormatDetector:
if not self.stripped_think_start and self.think_start_token in current_text:
current_text = current_text.replace(self.think_start_token, "")
self.stripped_think_start = True
self._in_reasoning = True
# Handle end of reasoning block
if self._in_reasoning and self.think_end_token in current_text:
......@@ -131,11 +141,11 @@ class Qwen3Detector(BaseReasoningFormatDetector):
"""
def __init__(self, stream_reasoning: bool = True):
# Qwen3 is assumed to be reasoning until `</think>` token
# Qwen3 won't be in reasoning mode when user passes `enable_thinking=False`
super().__init__(
"<think>",
"</think>",
force_reasoning=True,
force_reasoning=False,
stream_reasoning=stream_reasoning,
)
......@@ -151,12 +161,12 @@ class ReasoningParser:
If True, streams reasoning content as it arrives.
"""
DetectorMap: Dict[str, BaseReasoningFormatDetector] = {
DetectorMap: Dict[str, Type[BaseReasoningFormatDetector]] = {
"deepseek-r1": DeepSeekR1Detector,
"qwen3": Qwen3Detector,
}
def __init__(self, model_type: str = None, stream_reasoning: bool = True):
def __init__(self, model_type: Optional[str] = None, stream_reasoning: bool = True):
if not model_type:
raise ValueError("Model type must be specified")
......
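A hedged sketch of the one-time parsing behaviour changed above. It assumes the method carrying the "One-time parsing" docstring is named `detect_and_parse` and that the detectors live in `sglang.srt.reasoning_parser`; neither is shown explicitly in this hunk.

from sglang.srt.reasoning_parser import Qwen3Detector  # assumed module path

detector = Qwen3Detector(stream_reasoning=False)

# With force_reasoning now False, text without a <think> block passes through untouched.
result = detector.detect_and_parse("The answer is 4.")
assert result.normal_text == "The answer is 4."

# Text that opens a <think> block is split into reasoning and normal text.
result = detector.detect_and_parse("<think>2 + 2 = 4</think>The answer is 4.")
assert result.reasoning_text == "2 + 2 = 4"
assert result.normal_text == "The answer is 4."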
# sglang/test/srt/openai/conftest.py
import os
import socket
import subprocess
import sys
import tempfile
import time
from contextlib import closing
from typing import Generator
import pytest
import requests
from sglang.srt.utils import kill_process_tree # reuse SGLang helper
from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST
SERVER_MODULE = "sglang.srt.entrypoints.openai.api_server"
DEFAULT_MODEL = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
STARTUP_TIMEOUT = float(os.getenv("SGLANG_OPENAI_STARTUP_TIMEOUT", 120))
def _pick_free_port() -> int:
with closing(socket.socket()) as s:
s.bind(("127.0.0.1", 0))
return s.getsockname()[1]
def _wait_until_healthy(proc: subprocess.Popen, base: str, timeout: float) -> None:
start = time.perf_counter()
while time.perf_counter() - start < timeout:
if proc.poll() is not None: # crashed
raise RuntimeError("api_server terminated prematurely")
try:
if requests.get(f"{base}/health", timeout=1).status_code == 200:
return
except requests.RequestException:
pass
time.sleep(0.4)
raise RuntimeError("api_server readiness probe timed out")
def launch_openai_server(model: str = DEFAULT_MODEL, **kw):
"""Spawn the draft OpenAI-compatible server and wait until it's ready."""
port = _pick_free_port()
cmd = [
sys.executable,
"-m",
SERVER_MODULE,
"--model-path",
model,
"--host",
"127.0.0.1",
"--port",
str(port),
*map(str, kw.get("args", [])),
]
env = {**os.environ, **kw.get("env", {})}
# Write logs to a temp file so the child never blocks on a full pipe.
log_file = tempfile.NamedTemporaryFile("w+", delete=False)
proc = subprocess.Popen(
cmd,
env=env,
stdout=log_file,
stderr=subprocess.STDOUT,
text=True,
)
base = f"http://127.0.0.1:{port}"
try:
_wait_until_healthy(proc, base, STARTUP_TIMEOUT)
except Exception as e:
proc.terminate()
proc.wait(5)
log_file.seek(0)
print("\n--- api_server log ---\n", log_file.read(), file=sys.stderr)
raise e
return proc, base, log_file
@pytest.fixture(scope="session")
def openai_server() -> Generator[str, None, None]:
"""PyTest fixture that provides the server's base URL and cleans up."""
proc, base, log_file = launch_openai_server()
yield base
kill_process_tree(proc.pid)
log_file.close()
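For completeness, a hedged sketch of driving `launch_openai_server` directly, outside the pytest fixture; it assumes a machine where the default model can actually be served.

import requests
from sglang.srt.utils import kill_process_tree

proc, base, log_file = launch_openai_server()   # uses DEFAULT_MODEL
try:
    models = requests.get(f"{base}/v1/models", timeout=5).json()
    print([card["id"] for card in models["data"]])
finally:
    kill_process_tree(proc.pid)
    log_file.close()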
# sglang/test/srt/openai/test_server.py
import requests
from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST as MODEL_ID
def test_health(openai_server: str):
r = requests.get(f"{openai_server}/health")
assert r.status_code == 200
# FastAPI returns an empty body → r.text == ""
assert r.text == ""
def test_models_endpoint(openai_server: str):
r = requests.get(f"{openai_server}/v1/models")
assert r.status_code == 200, r.text
payload = r.json()
# Basic contract
assert "data" in payload and isinstance(payload["data"], list) and payload["data"]
# Validate fields of the first model card
first = payload["data"][0]
for key in ("id", "root", "max_model_len"):
assert key in first, f"missing {key} in {first}"
# max_model_len must be positive
assert isinstance(first["max_model_len"], int) and first["max_model_len"] > 0
# The server should report the same model id we launched it with
ids = {m["id"] for m in payload["data"]}
assert MODEL_ID in ids
def test_get_model_info(openai_server: str):
r = requests.get(f"{openai_server}/get_model_info")
assert r.status_code == 200, r.text
info = r.json()
expected_keys = {"model_path", "tokenizer_path", "is_generation"}
assert expected_keys.issubset(info.keys())
# model_path must end with the one we passed on the CLI
assert info["model_path"].endswith(MODEL_ID)
# is_generation is documented as a boolean
assert isinstance(info["is_generation"], bool)
def test_unknown_route_returns_404(openai_server: str):
r = requests.get(f"{openai_server}/definitely-not-a-real-route")
assert r.status_code == 404
......@@ -5,6 +5,7 @@ Run with:
"""
import unittest
from typing import Optional
from unittest.mock import AsyncMock, Mock, patch
from sglang.srt.entrypoints.openai.protocol import CompletionRequest
......@@ -12,6 +13,17 @@ from sglang.srt.entrypoints.openai.serving_completions import OpenAIServingCompl
from sglang.srt.managers.tokenizer_manager import TokenizerManager
class _MockTemplateManager:
"""Minimal mock for TemplateManager."""
def __init__(self):
self.chat_template_name: Optional[str] = None
self.jinja_template_content_format: Optional[str] = None
self.completion_template_name: Optional[str] = (
None # Set to None to avoid template processing
)
class ServingCompletionTestCase(unittest.TestCase):
"""Bundle all prompt/echo tests in one TestCase."""
......@@ -31,7 +43,8 @@ class ServingCompletionTestCase(unittest.TestCase):
tm.generate_request = AsyncMock()
tm.create_abort_task = Mock()
self.sc = OpenAIServingCompletion(tm)
self.template_manager = _MockTemplateManager()
self.sc = OpenAIServingCompletion(tm, self.template_manager)
# ---------- prompt-handling ----------
def test_single_string_prompt(self):
......@@ -44,20 +57,6 @@ class ServingCompletionTestCase(unittest.TestCase):
internal, _ = self.sc._convert_to_internal_request(req)
self.assertEqual(internal.input_ids, [1, 2, 3, 4])
def test_completion_template_handling(self):
req = CompletionRequest(
model="x", prompt="def f():", suffix="return 1", max_tokens=100
)
with patch(
"sglang.srt.entrypoints.openai.serving_completions.is_completion_template_defined",
return_value=True,
), patch(
"sglang.srt.entrypoints.openai.serving_completions.generate_completion_prompt_from_request",
return_value="processed_prompt",
):
internal, _ = self.sc._convert_to_internal_request(req)
self.assertEqual(internal.text, "processed_prompt")
# ---------- echo-handling ----------
def test_echo_with_string_prompt_streaming(self):
req = CompletionRequest(model="x", prompt="Hello", max_tokens=1, echo=True)
......