Unverified Commit 70c471a8 authored by Xinyuan Tong, committed by GitHub

[Refactor] OAI Server components (#7167)


Signed-off-by: Xinyuan Tong <justinning0323@outlook.com>
parent 1a9c2c92
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Pydantic models for OpenAI API protocol"""
import time
from typing import Dict, List, Optional, Union
from pydantic import BaseModel, Field, field_validator, model_validator
from typing_extensions import Literal
class ModelCard(BaseModel):
"""Model cards."""
id: str
object: str = "model"
created: int = Field(default_factory=lambda: int(time.time()))
owned_by: str = "sglang"
root: Optional[str] = None
max_model_len: Optional[int] = None
class ModelList(BaseModel):
"""Model list consists of model cards."""
object: str = "list"
data: List[ModelCard] = Field(default_factory=list)
class ErrorResponse(BaseModel):
object: str = "error"
message: str
type: str
param: Optional[str] = None
code: int
class LogProbs(BaseModel):
text_offset: List[int] = Field(default_factory=list)
token_logprobs: List[Optional[float]] = Field(default_factory=list)
tokens: List[str] = Field(default_factory=list)
top_logprobs: List[Optional[Dict[str, float]]] = Field(default_factory=list)
class TopLogprob(BaseModel):
token: str
bytes: List[int]
logprob: float
class ChatCompletionTokenLogprob(BaseModel):
token: str
bytes: List[int]
logprob: float
top_logprobs: List[TopLogprob]
class ChoiceLogprobs(BaseModel):
    # Built for the v1/chat/completions response
content: List[ChatCompletionTokenLogprob]
class UsageInfo(BaseModel):
prompt_tokens: int = 0
total_tokens: int = 0
completion_tokens: Optional[int] = 0
# only used to return cached tokens when --enable-cache-report is set
prompt_tokens_details: Optional[Dict[str, int]] = None
class StreamOptions(BaseModel):
include_usage: Optional[bool] = False
class JsonSchemaResponseFormat(BaseModel):
name: str
description: Optional[str] = None
# use alias to workaround pydantic conflict
schema_: Optional[Dict[str, object]] = Field(alias="schema", default=None)
strict: Optional[bool] = False
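# Illustrative sketch (not part of this refactor): a field literally named "schema" would
# shadow an attribute on pydantic's BaseModel, hence the alias above. Callers pass "schema"
# on input and read it back with model_dump(by_alias=True). The schema below is a placeholder.
def _example_json_schema_alias() -> dict:
    fmt = JsonSchemaResponseFormat(
        name="person",
        schema={"type": "object", "properties": {"name": {"type": "string"}}},
    )
    # The alias round-trips: the attribute is schema_, the dumped key is "schema".
    assert fmt.schema_ == fmt.model_dump(by_alias=True)["schema"]
    return fmt.model_dump(by_alias=True)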
class FileRequest(BaseModel):
# https://platform.openai.com/docs/api-reference/files/create
file: bytes # The File object (not file name) to be uploaded
purpose: str = (
"batch" # The intended purpose of the uploaded file, default is "batch"
)
class FileResponse(BaseModel):
id: str
object: str = "file"
bytes: int
created_at: int
filename: str
purpose: str
class FileDeleteResponse(BaseModel):
id: str
object: str = "file"
deleted: bool
class BatchRequest(BaseModel):
input_file_id: (
str # The ID of an uploaded file that contains requests for the new batch
)
endpoint: str # The endpoint to be used for all requests in the batch
completion_window: str # The time frame within which the batch should be processed
metadata: Optional[dict] = None # Optional custom metadata for the batch
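# Illustrative only: a batch job in the shape this model expects. The file id is a
# placeholder standing in for the id returned earlier by the file-upload endpoint.
def _example_batch_request() -> BatchRequest:
    return BatchRequest(
        input_file_id="file-abc123",  # placeholder id from a prior FileResponse
        endpoint="/v1/chat/completions",
        completion_window="24h",
        metadata={"job": "nightly-eval"},
    )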
class BatchResponse(BaseModel):
id: str
object: str = "batch"
endpoint: str
errors: Optional[dict] = None
input_file_id: str
completion_window: str
status: str = "validating"
output_file_id: Optional[str] = None
error_file_id: Optional[str] = None
created_at: int
in_progress_at: Optional[int] = None
expires_at: Optional[int] = None
finalizing_at: Optional[int] = None
completed_at: Optional[int] = None
failed_at: Optional[int] = None
expired_at: Optional[int] = None
cancelling_at: Optional[int] = None
cancelled_at: Optional[int] = None
request_counts: Optional[dict] = None
metadata: Optional[dict] = None
class CompletionRequest(BaseModel):
# Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/completions/create
model: str
prompt: Union[List[int], List[List[int]], str, List[str]]
best_of: Optional[int] = None
echo: bool = False
frequency_penalty: float = 0.0
logit_bias: Optional[Dict[str, float]] = None
logprobs: Optional[int] = None
max_tokens: int = 16
n: int = 1
presence_penalty: float = 0.0
seed: Optional[int] = None
stop: Optional[Union[str, List[str]]] = None
stream: bool = False
stream_options: Optional[StreamOptions] = None
suffix: Optional[str] = None
temperature: float = 1.0
top_p: float = 1.0
user: Optional[str] = None
    # Extra parameters for the SRT backend only; they are ignored by OpenAI models.
top_k: int = -1
min_p: float = 0.0
min_tokens: int = 0
json_schema: Optional[str] = None
regex: Optional[str] = None
ebnf: Optional[str] = None
repetition_penalty: float = 1.0
stop_token_ids: Optional[List[int]] = None
no_stop_trim: bool = False
ignore_eos: bool = False
skip_special_tokens: bool = True
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
session_params: Optional[Dict] = None
# For PD disaggregation
bootstrap_host: Optional[str] = None
bootstrap_port: Optional[int] = None
bootstrap_room: Optional[int] = None
@field_validator("max_tokens")
@classmethod
def validate_max_tokens_positive(cls, v):
if v is not None and v <= 0:
raise ValueError("max_tokens must be positive")
return v
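# A minimal sketch of how the validator above surfaces to callers: a non-positive
# max_tokens is rejected at construction time with a pydantic ValidationError.
# "dummy-model" is a placeholder model name.
def _example_max_tokens_validation() -> None:
    from pydantic import ValidationError
    CompletionRequest(model="dummy-model", prompt="Hello", max_tokens=16)  # accepted
    try:
        CompletionRequest(model="dummy-model", prompt="Hello", max_tokens=0)
    except ValidationError as exc:
        # The "max_tokens must be positive" message is carried in the error details.
        print(exc.errors()[0]["msg"])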
class CompletionResponseChoice(BaseModel):
index: int
text: str
logprobs: Optional[LogProbs] = None
finish_reason: Literal["stop", "length", "content_filter", "abort"]
matched_stop: Union[None, int, str] = None
class CompletionResponse(BaseModel):
id: str
object: str = "text_completion"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[CompletionResponseChoice]
usage: UsageInfo
class CompletionResponseStreamChoice(BaseModel):
index: int
text: str
logprobs: Optional[LogProbs] = None
finish_reason: Optional[Literal["stop", "length", "content_filter"]] = None
matched_stop: Union[None, int, str] = None
class CompletionStreamResponse(BaseModel):
id: str
object: str = "text_completion"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[CompletionResponseStreamChoice]
usage: Optional[UsageInfo] = None
class ChatCompletionMessageContentTextPart(BaseModel):
type: Literal["text"]
text: str
class ChatCompletionMessageContentImageURL(BaseModel):
url: str
detail: Optional[Literal["auto", "low", "high"]] = "auto"
class ChatCompletionMessageContentAudioURL(BaseModel):
url: str
class ChatCompletionMessageContentImagePart(BaseModel):
type: Literal["image_url"]
image_url: ChatCompletionMessageContentImageURL
modalities: Optional[Literal["image", "multi-images", "video"]] = "image"
class ChatCompletionMessageContentAudioPart(BaseModel):
type: Literal["audio_url"]
audio_url: ChatCompletionMessageContentAudioURL
ChatCompletionMessageContentPart = Union[
ChatCompletionMessageContentTextPart,
ChatCompletionMessageContentImagePart,
ChatCompletionMessageContentAudioPart,
]
class FunctionResponse(BaseModel):
"""Function response."""
name: Optional[str] = None
arguments: Optional[str] = None
class ToolCall(BaseModel):
"""Tool call response."""
id: Optional[str] = None
index: Optional[int] = None
type: Literal["function"] = "function"
function: FunctionResponse
class ChatCompletionMessageGenericParam(BaseModel):
role: Literal["system", "assistant", "tool"]
content: Union[str, List[ChatCompletionMessageContentTextPart], None]
tool_call_id: Optional[str] = None
name: Optional[str] = None
reasoning_content: Optional[str] = None
tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])
class ChatCompletionMessageUserParam(BaseModel):
role: Literal["user"]
content: Union[str, List[ChatCompletionMessageContentPart]]
ChatCompletionMessageParam = Union[
ChatCompletionMessageGenericParam, ChatCompletionMessageUserParam
]
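# Illustrative only: a user message mixing text and image parts, matching the content
# part unions above. The image URL is a placeholder, not a real endpoint.
def _example_multimodal_user_message() -> ChatCompletionMessageUserParam:
    return ChatCompletionMessageUserParam(
        role="user",
        content=[
            ChatCompletionMessageContentTextPart(type="text", text="What is in this image?"),
            ChatCompletionMessageContentImagePart(
                type="image_url",
                image_url=ChatCompletionMessageContentImageURL(url="https://example.com/cat.png"),
            ),
        ],
    )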
class ResponseFormat(BaseModel):
type: Literal["text", "json_object", "json_schema"]
json_schema: Optional[JsonSchemaResponseFormat] = None
class StructuresResponseFormat(BaseModel):
begin: str
schema_: Optional[Dict[str, object]] = Field(alias="schema", default=None)
end: str
class StructuralTagResponseFormat(BaseModel):
type: Literal["structural_tag"]
structures: List[StructuresResponseFormat]
triggers: List[str]
class Function(BaseModel):
"""Function descriptions."""
description: Optional[str] = Field(default=None, examples=[None])
name: Optional[str] = None
parameters: Optional[object] = None
strict: bool = False
class Tool(BaseModel):
"""Function wrapper."""
type: str = Field(default="function", examples=["function"])
function: Function
class ToolChoiceFuncName(BaseModel):
"""The name of tool choice function."""
name: Optional[str] = None
class ToolChoice(BaseModel):
"""The tool choice definition."""
function: ToolChoiceFuncName
type: Literal["function"] = Field(default="function", examples=["function"])
class ChatCompletionRequest(BaseModel):
# Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/chat/create
messages: List[ChatCompletionMessageParam]
model: str
frequency_penalty: float = 0.0
logit_bias: Optional[Dict[str, float]] = None
logprobs: bool = False
top_logprobs: Optional[int] = None
max_tokens: Optional[int] = Field(
default=None,
deprecated="max_tokens is deprecated in favor of the max_completion_tokens field",
description="The maximum number of tokens that can be generated in the chat completion. ",
)
max_completion_tokens: Optional[int] = Field(
default=None,
description="The maximum number of completion tokens for a chat completion request, "
"including visible output tokens and reasoning tokens. Input tokens are not included. ",
)
n: int = 1
presence_penalty: float = 0.0
response_format: Optional[Union[ResponseFormat, StructuralTagResponseFormat]] = None
seed: Optional[int] = None
stop: Optional[Union[str, List[str]]] = None
stream: bool = False
stream_options: Optional[StreamOptions] = None
temperature: float = 0.7
top_p: float = 1.0
user: Optional[str] = None
tools: Optional[List[Tool]] = Field(default=None, examples=[None])
tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] = Field(
default="auto", examples=["none"]
) # noqa
@model_validator(mode="before")
@classmethod
def set_tool_choice_default(cls, values):
if isinstance(values, dict):
if values.get("tool_choice") is None:
if values.get("tools") is None:
values["tool_choice"] = "none"
else:
values["tool_choice"] = "auto"
return values
@field_validator("messages")
@classmethod
def validate_messages_not_empty(cls, v):
if not v:
raise ValueError("Messages cannot be empty")
return v
    # Extra parameters for the SRT backend only; they are ignored by OpenAI models.
top_k: int = -1
min_p: float = 0.0
min_tokens: int = 0
regex: Optional[str] = None
ebnf: Optional[str] = None
repetition_penalty: float = 1.0
stop_token_ids: Optional[List[int]] = None
no_stop_trim: bool = False
ignore_eos: bool = False
continue_final_message: bool = False
skip_special_tokens: bool = True
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
session_params: Optional[Dict] = None
separate_reasoning: bool = True
stream_reasoning: bool = True
chat_template_kwargs: Optional[Dict] = None
# The request id.
rid: Optional[str] = None
# For PD disaggregation
bootstrap_host: Optional[str] = None
bootstrap_port: Optional[int] = None
bootstrap_room: Optional[int] = None
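# A small sketch of the set_tool_choice_default validator above: when tool_choice is
# omitted it resolves to "none" without tools and to "auto" when tools are supplied.
# "dummy-model" and get_weather are placeholders.
def _example_tool_choice_default() -> None:
    msgs = [{"role": "user", "content": "What is the weather in Paris?"}]
    no_tools = ChatCompletionRequest(model="dummy-model", messages=msgs)
    with_tools = ChatCompletionRequest(
        model="dummy-model",
        messages=msgs,
        tools=[{"type": "function", "function": {"name": "get_weather"}}],
    )
    assert no_tools.tool_choice == "none"
    assert with_tools.tool_choice == "auto"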
class ChatMessage(BaseModel):
role: Optional[str] = None
content: Optional[str] = None
reasoning_content: Optional[str] = None
tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])
class ChatCompletionResponseChoice(BaseModel):
index: int
message: ChatMessage
logprobs: Optional[Union[LogProbs, ChoiceLogprobs]] = None
finish_reason: Literal[
"stop", "length", "tool_calls", "content_filter", "function_call", "abort"
]
matched_stop: Union[None, int, str] = None
class ChatCompletionResponse(BaseModel):
id: str
object: str = "chat.completion"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[ChatCompletionResponseChoice]
usage: UsageInfo
class DeltaMessage(BaseModel):
role: Optional[str] = None
content: Optional[str] = None
reasoning_content: Optional[str] = None
tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])
class ChatCompletionResponseStreamChoice(BaseModel):
index: int
delta: DeltaMessage
logprobs: Optional[Union[LogProbs, ChoiceLogprobs]] = None
finish_reason: Optional[
Literal["stop", "length", "tool_calls", "content_filter", "function_call"]
] = None
matched_stop: Union[None, int, str] = None
class ChatCompletionStreamResponse(BaseModel):
id: str
object: str = "chat.completion.chunk"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[ChatCompletionResponseStreamChoice]
usage: Optional[UsageInfo] = None
class MultimodalEmbeddingInput(BaseModel):
text: Optional[str] = None
image: Optional[str] = None
EmbeddingInput = Union[
List[int], List[List[int]], str, List[str], List[MultimodalEmbeddingInput]
]
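# Illustrative only: the shapes EmbeddingInput accepts -- plain text, token IDs, batches
# of either, or multimodal text/image pairs. The token IDs and base64 image are stubs.
def _example_embedding_inputs() -> List[EmbeddingInput]:
    return [
        "hello world",                          # single string
        ["hello", "world"],                     # batch of strings
        [101, 2023, 102],                       # single pre-tokenized sequence
        [[101, 2023, 102], [101, 2003, 102]],   # batch of token sequences
        [MultimodalEmbeddingInput(text="a photo of a cat", image="<base64-image>")],
    ]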
class EmbeddingRequest(BaseModel):
# Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/embeddings/create
input: EmbeddingInput
model: str
encoding_format: str = "float"
    dimensions: Optional[int] = None
user: Optional[str] = None
# The request id.
rid: Optional[str] = None
class EmbeddingObject(BaseModel):
embedding: List[float]
index: int
object: str = "embedding"
class EmbeddingResponse(BaseModel):
data: List[EmbeddingObject]
model: str
object: str = "list"
usage: Optional[UsageInfo] = None
class ScoringRequest(BaseModel):
query: Optional[Union[str, List[int]]] = (
None # Query text or pre-tokenized token IDs
)
items: Optional[Union[str, List[str], List[List[int]]]] = (
None # Item text(s) or pre-tokenized token IDs
)
label_token_ids: Optional[List[int]] = (
None # Token IDs to compute probabilities for
)
apply_softmax: bool = False
item_first: bool = False
model: str
class ScoringResponse(BaseModel):
scores: List[
List[float]
] # List of lists of probabilities, each in the order of label_token_ids
model: str
usage: Optional[UsageInfo] = None
object: str = "scoring"
OpenAIServingRequest = Union[
ChatCompletionRequest, CompletionRequest, EmbeddingRequest, ScoringRequest
]
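# A hedged sketch of a scoring call: compute probabilities for the label tokens after each
# query/item pair. The token IDs below are placeholders, not from a real tokenizer.
def _example_scoring_request() -> ScoringRequest:
    return ScoringRequest(
        model="dummy-model",
        query="Is this review positive?",
        items=["The food was great!", "The service was terrible."],
        label_token_ids=[9891, 2822],  # placeholder ids, e.g. for " yes" / " no"
        apply_softmax=True,
    )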
import json
import logging
import uuid
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Union
from fastapi import Request
from fastapi.responses import ORJSONResponse, StreamingResponse
from sglang.srt.entrypoints.openai.protocol import (
ErrorResponse,
OpenAIServingRequest,
UsageInfo,
)
from sglang.srt.managers.io_struct import GenerateReqInput
from sglang.srt.managers.tokenizer_manager import TokenizerManager
logger = logging.getLogger(__name__)
# Base class for specific endpoint handlers
class OpenAIServingBase(ABC):
"""Abstract base class for OpenAI endpoint handlers"""
def __init__(self, tokenizer_manager: TokenizerManager):
self.tokenizer_manager = tokenizer_manager
async def handle_request(
self, request: OpenAIServingRequest, raw_request: Request
) -> Union[Any, StreamingResponse, ErrorResponse]:
"""Handle the specific request type with common pattern"""
try:
# Validate request
error_msg = self._validate_request(request)
if error_msg:
return self.create_error_response(error_msg)
# Convert to internal format
adapted_request, processed_request = self._convert_to_internal_request(
[request], [self._generate_request_id_base(request)]
)
            # Note (Xinyuan): raw_request below is only used to detect client disconnects
if hasattr(request, "stream") and request.stream:
return await self._handle_streaming_request(
adapted_request, processed_request, raw_request
)
else:
return await self._handle_non_streaming_request(
adapted_request, processed_request, raw_request
)
except Exception as e:
logger.error(f"Error in request: {e}")
return self.create_error_response(
message=f"Internal server error: {str(e)}",
err_type="InternalServerError",
status_code=500,
)
@abstractmethod
def _request_id_prefix(self) -> str:
"""Generate request ID based on request type"""
pass
def _generate_request_id_base(self, request: OpenAIServingRequest) -> str:
"""Generate request ID based on request type"""
if rid := getattr(request, "rid", None):
return rid
return f"{self._request_id_prefix()}{uuid.uuid4().hex}"
@abstractmethod
def _convert_to_internal_request(
self,
all_requests: List[OpenAIServingRequest],
request_ids: List[str],
) -> tuple[
GenerateReqInput, Union[OpenAIServingRequest, List[OpenAIServingRequest]]
]:
"""Convert OpenAI request to internal format"""
pass
async def _handle_streaming_request(
self,
adapted_request: GenerateReqInput,
request: OpenAIServingRequest,
raw_request: Request,
) -> StreamingResponse:
"""Handle streaming request
Override this method in child classes that support streaming requests.
"""
return self.create_error_response(
message=f"{self.__class__.__name__} does not support streaming requests",
err_type="NotImplementedError",
status_code=501,
)
async def _handle_non_streaming_request(
self,
adapted_request: GenerateReqInput,
request: OpenAIServingRequest,
raw_request: Request,
) -> Union[Any, ErrorResponse]:
"""Handle non-streaming request
Override this method in child classes that support non-streaming requests.
"""
return self.create_error_response(
message=f"{self.__class__.__name__} does not support non-streaming requests",
err_type="NotImplementedError",
status_code=501,
)
def _validate_request(self, request: OpenAIServingRequest) -> Optional[str]:
"""Validate request"""
pass
def _calculate_streaming_usage_base(
self,
prompt_tokens: Dict[int, int],
completion_tokens: Dict[int, int],
cached_tokens: Dict[int, int],
n_choices: int,
) -> UsageInfo:
"""Calculate usage information for streaming responses (common logic)"""
total_prompt_tokens = sum(
tokens for i, tokens in prompt_tokens.items() if i % n_choices == 0
)
total_completion_tokens = sum(tokens for tokens in completion_tokens.values())
cache_report = self.tokenizer_manager.server_args.enable_cache_report
prompt_tokens_details = None
if cache_report:
cached_tokens_sum = sum(tokens for tokens in cached_tokens.values())
if cached_tokens_sum > 0:
prompt_tokens_details = {"cached_tokens": cached_tokens_sum}
return UsageInfo(
prompt_tokens=total_prompt_tokens,
completion_tokens=total_completion_tokens,
total_tokens=total_prompt_tokens + total_completion_tokens,
prompt_tokens_details=prompt_tokens_details,
)
def create_error_response(
self,
message: str,
err_type: str = "BadRequestError",
status_code: int = 400,
param: Optional[str] = None,
) -> ORJSONResponse:
"""Create an error response"""
error = ErrorResponse(
object="error",
message=message,
type=err_type,
param=param,
code=status_code,
)
return ORJSONResponse(content=error.model_dump(), status_code=status_code)
def create_streaming_error_response(
self,
message: str,
err_type: str = "BadRequestError",
status_code: int = 400,
) -> str:
"""Create a streaming error response"""
error = ErrorResponse(
object="error",
message=message,
type=err_type,
param=None,
code=status_code,
)
return json.dumps({"error": error.model_dump()})
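# A minimal sketch (not shipped in this refactor) of what a concrete handler built on
# OpenAIServingBase must provide: a request-id prefix and the conversion to the internal
# GenerateReqInput. Streaming/non-streaming handling and error responses then come from
# the base-class hooks above. The request shape handled here is hypothetical.
class _ExampleServing(OpenAIServingBase):
    def _request_id_prefix(self) -> str:
        return "example-"

    def _convert_to_internal_request(self, all_requests, request_ids):
        # Assume a single text-only request, purely for illustration.
        request = all_requests[0]
        adapted_request = GenerateReqInput(
            text=getattr(request, "prompt", ""),
            rid=request_ids[0],
            stream=getattr(request, "stream", False),
        )
        return adapted_request, request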
import base64
import json
import logging
import time
import uuid
from typing import Any, Dict, List, Optional, Union
from fastapi import Request
from fastapi.responses import StreamingResponse
from sglang.srt.conversation import generate_chat_conv
from sglang.srt.entrypoints.openai.protocol import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseChoice,
ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse,
ChatCompletionTokenLogprob,
ChatMessage,
ChoiceLogprobs,
DeltaMessage,
ErrorResponse,
FunctionResponse,
LogProbs,
ToolCall,
TopLogprob,
)
from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
from sglang.srt.entrypoints.openai.utils import (
aggregate_token_usage,
detect_template_content_format,
process_content_for_template_format,
to_openai_style_logprobs,
)
from sglang.srt.function_call.function_call_parser import FunctionCallParser
from sglang.srt.managers.io_struct import GenerateReqInput
from sglang.srt.reasoning_parser import ReasoningParser
from sglang.utils import convert_json_schema_to_str
logger = logging.getLogger(__name__)
class OpenAIServingChat(OpenAIServingBase):
"""Handler for chat completion requests"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Instance-specific cache for template content format detection
self._cached_chat_template = None
self._cached_template_format = None
def _request_id_prefix(self) -> str:
return "chatcmpl-"
def _validate_request(self, request: ChatCompletionRequest) -> Optional[str]:
"""Validate chat messages format and content"""
if not (messages := request.messages):
return "Messages cannot be empty"
# Check for alternating user/assistant pattern (optional validation)
roles = [msg.role for msg in messages]
# First message should typically be from user or system
if roles[0] not in ["user", "system"]:
return "First message should be from 'user' or 'system'"
# Check for consecutive assistant messages (which might indicate an error)
for i in range(1, len(roles)):
if roles[i] == "assistant" and roles[i - 1] == "assistant":
                # Consecutive assistant messages are allowed in some flows; accept them without error
pass
# Validate message content
for i, msg in enumerate(messages):
if msg.role == "user":
if not msg.content:
return f"User message at index {i} has no content"
elif msg.role == "assistant":
# Assistant messages can have no content if they have tool_calls
if not msg.content and not getattr(msg, "tool_calls", None):
return (
f"Assistant message at index {i} has no content or tool calls"
)
return None
def _convert_to_internal_request(
self,
all_requests: List[ChatCompletionRequest],
request_ids: List[str],
) -> tuple[
GenerateReqInput, Union[ChatCompletionRequest, List[ChatCompletionRequest]]
]:
"""Convert OpenAI chat completion request to internal format"""
input_ids = []
prompts = []
sampling_params_list = []
image_data_list = []
audio_data_list = []
return_logprobs = []
logprob_start_lens = []
top_logprobs_nums = []
modalities_list = []
lora_paths = []
is_multimodal = self.tokenizer_manager.model_config.is_multimodal
for request in all_requests:
# Process messages and apply chat template
(
prompt,
prompt_ids,
image_data,
audio_data,
modalities,
stop,
tool_call_constraint,
) = self._process_messages(request, is_multimodal)
input_ids.append(prompt_ids)
prompts.append(prompt)
return_logprobs.append(request.logprobs)
logprob_start_lens.append(-1)
top_logprobs_nums.append(request.top_logprobs or 0)
lora_paths.append(request.lora_path)
# Build sampling parameters
sampling_params = self._build_sampling_params(
request, stop, tool_call_constraint
)
sampling_params_list.append(sampling_params)
image_data_list.append(image_data)
audio_data_list.append(audio_data)
modalities_list.append(modalities)
# Handle single vs multiple requests
if len(all_requests) == 1:
if is_multimodal:
prompt_kwargs = {"text": prompts[0]}
else:
if isinstance(input_ids[0], str):
prompt_kwargs = {"text": input_ids[0]}
else:
prompt_kwargs = {"input_ids": input_ids[0]}
sampling_params_list = sampling_params_list[0]
image_data_list = image_data_list[0]
audio_data_list = audio_data_list[0]
return_logprobs = return_logprobs[0]
logprob_start_lens = logprob_start_lens[0]
top_logprobs_nums = top_logprobs_nums[0]
modalities_list = modalities_list[0]
lora_paths = lora_paths[0]
request_ids = request_ids[0]
else:
if is_multimodal:
prompt_kwargs = {"text": prompts}
else:
if isinstance(input_ids[0], str):
prompt_kwargs = {"text": input_ids}
else:
prompt_kwargs = {"input_ids": input_ids}
adapted_request = GenerateReqInput(
**prompt_kwargs,
image_data=image_data_list,
audio_data=audio_data_list,
sampling_params=sampling_params_list,
return_logprob=return_logprobs,
logprob_start_len=logprob_start_lens,
top_logprobs_num=top_logprobs_nums,
stream=all_requests[0].stream,
return_text_in_logprobs=True,
rid=request_ids,
modalities=modalities_list,
lora_path=lora_paths,
bootstrap_host=all_requests[0].bootstrap_host,
bootstrap_port=all_requests[0].bootstrap_port,
bootstrap_room=all_requests[0].bootstrap_room,
)
return adapted_request, (
all_requests if len(all_requests) > 1 else all_requests[0]
)
def _process_messages(
self, request: ChatCompletionRequest, is_multimodal: bool
) -> tuple[
str,
Union[str, List[int]],
Optional[Any],
Optional[Any],
List[str],
List[str],
Optional[Any],
]:
"""Process chat messages and apply chat template"""
tool_call_constraint = None
prompt = ""
prompt_ids = []
if not isinstance(request.messages, str):
# Apply chat template and its stop strings
tools = None
if request.tools and request.tool_choice != "none":
request.skip_special_tokens = False
if not isinstance(request.tool_choice, str):
tools = [
item.function.model_dump()
for item in request.tools
if item.function.name == request.tool_choice.function.name
]
else:
tools = [item.function.model_dump() for item in request.tools]
tool_call_parser = self.tokenizer_manager.server_args.tool_call_parser
parser = FunctionCallParser(request.tools, tool_call_parser)
tool_call_constraint = parser.get_structure_constraint(
request.tool_choice
)
# Use chat template
if (
hasattr(self.tokenizer_manager, "chat_template_name")
and self.tokenizer_manager.chat_template_name is None
):
prompt, prompt_ids, image_data, audio_data, modalities, stop = (
self._apply_jinja_template(request, tools, is_multimodal)
)
else:
prompt, image_data, audio_data, modalities, stop = (
self._apply_conversation_template(request)
)
if not is_multimodal:
prompt_ids = self.tokenizer_manager.tokenizer.encode(prompt)
else:
# Use raw prompt
prompt_ids = request.messages
stop = request.stop or []
image_data = None
audio_data = None
modalities = []
prompt = request.messages
return (
prompt,
prompt_ids,
image_data,
audio_data,
modalities,
stop,
tool_call_constraint,
)
def _apply_jinja_template(
self,
request: ChatCompletionRequest,
tools: Optional[List[Dict]],
is_multimodal: bool,
) -> tuple[str, List[int], Optional[Any], Optional[Any], List[str], List[str]]:
"""Apply Jinja chat template"""
openai_compatible_messages = []
image_data = []
audio_data = []
modalities = []
# Detect template content format
current_template = self.tokenizer_manager.tokenizer.chat_template
if current_template != self._cached_chat_template:
self._cached_chat_template = current_template
self._cached_template_format = detect_template_content_format(
current_template
)
logger.info(
f"Detected chat template content format: {self._cached_template_format}"
)
template_content_format = self._cached_template_format
for message in request.messages:
if message.content is None:
message.content = ""
msg_dict = message.model_dump()
# Process content based on detected template format
processed_msg = process_content_for_template_format(
msg_dict,
template_content_format,
image_data,
audio_data,
modalities,
)
openai_compatible_messages.append(processed_msg)
# Handle assistant prefix for continue_final_message
assistant_prefix = None
if (
openai_compatible_messages
and openai_compatible_messages[-1]["role"] == "assistant"
):
if request.continue_final_message:
assistant_prefix = openai_compatible_messages[-1]["content"]
openai_compatible_messages = openai_compatible_messages[:-1]
try:
prompt_ids = self.tokenizer_manager.tokenizer.apply_chat_template(
openai_compatible_messages,
tokenize=True,
add_generation_prompt=True,
tools=tools,
**(
request.chat_template_kwargs if request.chat_template_kwargs else {}
),
)
except Exception:
            # This branch is triggered when the chosen model expects a tools
            # input format that is incompatible with the OpenAI-style
            # tool-call format used by apply_chat_template (e.g. Mistral).
tools = (
[t if "function" in t else {"function": t} for t in tools]
if tools
else None
)
prompt_ids = self.tokenizer_manager.tokenizer.apply_chat_template(
openai_compatible_messages,
tokenize=True,
add_generation_prompt=True,
tools=tools,
**(
request.chat_template_kwargs if request.chat_template_kwargs else {}
),
)
if assistant_prefix:
encoded = self.tokenizer_manager.tokenizer.encode(assistant_prefix)
if encoded and encoded[0] == self.tokenizer_manager.tokenizer.bos_token_id:
encoded = encoded[1:]
prompt_ids += encoded
if is_multimodal:
prompt = self.tokenizer_manager.tokenizer.decode(prompt_ids)
stop = request.stop or []
return prompt, prompt_ids, image_data, audio_data, modalities, stop
def _apply_conversation_template(
self, request: ChatCompletionRequest
) -> tuple[str, Optional[Any], Optional[Any], List[str], List[str]]:
"""Apply conversation template"""
conv = generate_chat_conv(request, self.tokenizer_manager.chat_template_name)
# If we should continue the final assistant message, adjust the conversation.
if (
request.continue_final_message
and request.messages
and request.messages[-1].role == "assistant"
):
# Remove the auto-added blank assistant turn, if present.
if conv.messages and conv.messages[-1][1] is None:
conv.messages.pop()
# Rebuild the prompt from the conversation.
prompt = conv.get_prompt()
# Strip trailing stop tokens or separators that indicate end-of-assistant.
if isinstance(conv.stop_str, list):
for stop_token in conv.stop_str:
if prompt.endswith(stop_token):
prompt = prompt[: -len(stop_token)]
elif isinstance(conv.stop_str, str) and prompt.endswith(conv.stop_str):
prompt = prompt[: -len(conv.stop_str)]
if conv.sep and prompt.endswith(conv.sep):
prompt = prompt[: -len(conv.sep)]
if getattr(conv, "sep2", None) and prompt.endswith(conv.sep2):
prompt = prompt[: -len(conv.sep2)]
else:
prompt = conv.get_prompt()
image_data = conv.image_data
audio_data = conv.audio_data
modalities = conv.modalities
stop = conv.stop_str or [] if not request.ignore_eos else []
if request.stop:
if isinstance(request.stop, str):
stop.append(request.stop)
else:
stop.extend(request.stop)
return prompt, image_data, audio_data, modalities, stop
def _build_sampling_params(
self,
request: ChatCompletionRequest,
stop: List[str],
tool_call_constraint: Optional[Any],
) -> Dict[str, Any]:
"""Build sampling parameters for the request"""
sampling_params = {
"temperature": request.temperature,
"max_new_tokens": request.max_tokens or request.max_completion_tokens,
"min_new_tokens": request.min_tokens,
"stop": stop,
"stop_token_ids": request.stop_token_ids,
"top_p": request.top_p,
"top_k": request.top_k,
"min_p": request.min_p,
"presence_penalty": request.presence_penalty,
"frequency_penalty": request.frequency_penalty,
"repetition_penalty": request.repetition_penalty,
"regex": request.regex,
"ebnf": request.ebnf,
"n": request.n,
"no_stop_trim": request.no_stop_trim,
"ignore_eos": request.ignore_eos,
"skip_special_tokens": request.skip_special_tokens,
"logit_bias": request.logit_bias,
}
if request.response_format and request.response_format.type == "json_schema":
sampling_params["json_schema"] = convert_json_schema_to_str(
request.response_format.json_schema.schema_
)
elif request.response_format and request.response_format.type == "json_object":
sampling_params["json_schema"] = '{"type": "object"}'
elif (
request.response_format and request.response_format.type == "structural_tag"
):
sampling_params["structural_tag"] = convert_json_schema_to_str(
request.response_format.model_dump(by_alias=True)
)
# Check if there are already existing output constraints
has_existing_constraints = (
sampling_params.get("regex")
or sampling_params.get("ebnf")
or sampling_params.get("structural_tag")
or sampling_params.get("json_schema")
)
if tool_call_constraint and has_existing_constraints:
logger.warning("Constrained decoding is not compatible with tool calls.")
elif tool_call_constraint:
constraint_type, constraint_value = tool_call_constraint
if constraint_type == "structural_tag":
sampling_params[constraint_type] = convert_json_schema_to_str(
constraint_value.model_dump(by_alias=True)
)
else:
sampling_params[constraint_type] = constraint_value
return sampling_params
async def _handle_streaming_request(
self,
adapted_request: GenerateReqInput,
request: ChatCompletionRequest,
raw_request: Request,
) -> StreamingResponse:
"""Handle streaming chat completion request"""
async def generate_stream_resp():
parser_dict = {}
reasoning_parser_dict = {}
tool_call_first = True
is_firsts = {}
stream_buffers = {}
n_prev_tokens = {}
prompt_tokens = {}
completion_tokens = {}
cached_tokens = {}
try:
async for content in self.tokenizer_manager.generate_request(
adapted_request, raw_request
):
index = content.get("index", 0)
is_first = is_firsts.get(index, True)
stream_buffer = stream_buffers.get(index, "")
n_prev_token = n_prev_tokens.get(index, 0)
prompt_tokens[index] = content["meta_info"]["prompt_tokens"]
completion_tokens[index] = content["meta_info"]["completion_tokens"]
cached_tokens[index] = content["meta_info"].get("cached_tokens", 0)
# Handle logprobs
choice_logprobs = None
if request.logprobs:
choice_logprobs = self._process_streaming_logprobs(
content, n_prev_token
)
n_prev_token = len(
content["meta_info"]["output_token_logprobs"]
)
finish_reason = content["meta_info"]["finish_reason"]
finish_reason_type = (
finish_reason["type"] if finish_reason else None
)
# First chunk with role
if is_first:
is_first = False
delta = DeltaMessage(role="assistant")
choice_data = ChatCompletionResponseStreamChoice(
index=index,
delta=delta,
finish_reason=finish_reason_type,
matched_stop=(
finish_reason["matched"]
if finish_reason and "matched" in finish_reason
else None
),
logprobs=choice_logprobs,
)
chunk = ChatCompletionStreamResponse(
id=content["meta_info"]["id"],
created=int(time.time()),
choices=[choice_data],
model=request.model,
)
yield f"data: {chunk.model_dump_json()}\n\n"
# Process content delta
delta = content["text"][len(stream_buffer) :]
new_stream_buffer = stream_buffer + delta
# Handle reasoning content
                    enable_thinking = (request.chat_template_kwargs or {}).get(
                        "enable_thinking", True
                    )
if (
self.tokenizer_manager.server_args.reasoning_parser
and request.separate_reasoning
and enable_thinking
):
reasoning_text, delta = self._process_reasoning_stream(
index, delta, reasoning_parser_dict, content, request
)
if reasoning_text:
choice_data = ChatCompletionResponseStreamChoice(
index=index,
delta=DeltaMessage(reasoning_content=reasoning_text),
finish_reason=finish_reason_type,
)
chunk = ChatCompletionStreamResponse(
id=content["meta_info"]["id"],
created=int(time.time()),
choices=[choice_data],
model=request.model,
)
yield f"data: {chunk.model_dump_json()}\n\n"
if not delta:
stream_buffers[index] = new_stream_buffer
is_firsts[index] = is_first
n_prev_tokens[index] = n_prev_token
continue
# Handle tool calls
if request.tool_choice != "none" and request.tools:
async for chunk in self._process_tool_call_stream(
index,
delta,
parser_dict,
content,
request,
finish_reason_type,
):
yield chunk
else:
# Regular content
if delta or not (
request.stream_options
and request.stream_options.include_usage
):
choice_data = ChatCompletionResponseStreamChoice(
index=index,
delta=DeltaMessage(content=delta if delta else None),
finish_reason=(
None
if request.stream_options
and request.stream_options.include_usage
else finish_reason_type
),
matched_stop=(
finish_reason["matched"]
if finish_reason and "matched" in finish_reason
else None
),
logprobs=choice_logprobs,
)
chunk = ChatCompletionStreamResponse(
id=content["meta_info"]["id"],
created=int(time.time()),
choices=[choice_data],
model=request.model,
)
yield f"data: {chunk.model_dump_json()}\n\n"
stream_buffers[index] = new_stream_buffer
is_firsts[index] = is_first
n_prev_tokens[index] = n_prev_token
# Final chunk with usage
if request.stream_options and request.stream_options.include_usage:
usage = self._calculate_streaming_usage_base(
prompt_tokens, completion_tokens, cached_tokens, request.n
)
else:
usage = None
final_chunk = ChatCompletionStreamResponse(
id=content["meta_info"]["id"],
created=int(time.time()),
choices=[
ChatCompletionResponseStreamChoice(
index=index,
delta=DeltaMessage(),
finish_reason=finish_reason_type,
)
],
model=request.model,
usage=usage,
)
yield f"data: {final_chunk.model_dump_json()}\n\n"
except Exception as e:
error = self.create_streaming_error_response(str(e))
yield f"data: {error}\n\n"
yield "data: [DONE]\n\n"
return StreamingResponse(
generate_stream_resp(),
media_type="text/event-stream",
background=self.tokenizer_manager.create_abort_task(adapted_request),
)
async def _handle_non_streaming_request(
self,
adapted_request: GenerateReqInput,
request: ChatCompletionRequest,
raw_request: Request,
) -> Union[ChatCompletionResponse, ErrorResponse]:
"""Handle non-streaming chat completion request"""
try:
ret = await self.tokenizer_manager.generate_request(
adapted_request, raw_request
).__anext__()
except ValueError as e:
return self.create_error_response(str(e))
if not isinstance(ret, list):
ret = [ret]
response = self._build_chat_response(
request,
ret,
int(time.time()),
cache_report=self.tokenizer_manager.server_args.enable_cache_report,
tool_call_parser=self.tokenizer_manager.server_args.tool_call_parser,
reasoning_parser=self.tokenizer_manager.server_args.reasoning_parser,
)
return response
def _build_chat_response(
self,
request: ChatCompletionRequest,
ret: List[Dict[str, Any]],
created: int,
cache_report: bool = False,
tool_call_parser: Optional[str] = None,
reasoning_parser: Optional[str] = None,
) -> ChatCompletionResponse:
"""Build chat completion response from generation results"""
choices = []
for idx, ret_item in enumerate(ret):
# Process logprobs
choice_logprobs = None
if request.logprobs:
choice_logprobs = self._process_response_logprobs(ret_item)
finish_reason = ret_item["meta_info"]["finish_reason"]
text = ret_item["text"]
# Handle reasoning content
reasoning_text = None
            enable_thinking = (request.chat_template_kwargs or {}).get(
                "enable_thinking", True
            )
if reasoning_parser and request.separate_reasoning and enable_thinking:
try:
parser = ReasoningParser(
model_type=reasoning_parser, stream_reasoning=False
)
reasoning_text, text = parser.parse_non_stream(text)
except Exception as e:
logger.error(f"Reasoning parsing error: {e}")
return self.create_error_response(
"Failed to parse reasoning content",
err_type="InternalServerError",
status_code=500,
)
# Handle tool calls
tool_calls = None
if request.tool_choice != "none" and request.tools:
tool_calls, text, finish_reason = self._process_tool_calls(
text, request.tools, tool_call_parser, finish_reason
)
choice_data = ChatCompletionResponseChoice(
index=idx,
message=ChatMessage(
role="assistant",
content=text if text else None,
tool_calls=tool_calls,
reasoning_content=reasoning_text if reasoning_text else None,
),
logprobs=choice_logprobs,
finish_reason=finish_reason["type"] if finish_reason else None,
matched_stop=(
finish_reason["matched"]
if finish_reason and "matched" in finish_reason
else None
),
)
choices.append(choice_data)
# Calculate usage
usage = aggregate_token_usage(ret, request.n, cache_report)
return ChatCompletionResponse(
id=ret[0]["meta_info"]["id"],
created=created,
model=request.model,
choices=choices,
usage=usage,
)
def _process_logprobs_tokens(
self, logprobs: LogProbs, use_token_index: bool = False
) -> List[ChatCompletionTokenLogprob]:
"""Common helper to process logprobs tokens for both streaming and non-streaming
Args:
logprobs: LogProbs data from model
use_token_index: True for non-streaming (use token_idx), False for streaming (use index 0)
"""
token_logprobs = []
for token_idx, (token, logprob) in enumerate(
zip(logprobs.tokens, logprobs.token_logprobs)
):
token_bytes = list(token.encode("utf-8"))
top_logprobs = []
if logprobs.top_logprobs:
# - Non-streaming (use_token_index=True): uses token_idx for full data
# - Streaming (use_token_index=False): uses index 0 for pre-sliced data
top_logprobs_idx = token_idx if use_token_index else 0
for top_token, top_logprob in logprobs.top_logprobs[
top_logprobs_idx
].items():
top_token_bytes = list(top_token.encode("utf-8"))
top_logprobs.append(
TopLogprob(
token=top_token,
bytes=top_token_bytes,
logprob=top_logprob,
)
)
token_logprobs.append(
ChatCompletionTokenLogprob(
token=token,
bytes=token_bytes,
logprob=logprob,
top_logprobs=top_logprobs,
)
)
return token_logprobs
def _process_response_logprobs(self, ret_item: Dict[str, Any]) -> ChoiceLogprobs:
"""Process logprobs for non-streaming response"""
logprobs = to_openai_style_logprobs(
output_token_logprobs=ret_item["meta_info"]["output_token_logprobs"],
output_top_logprobs=ret_item["meta_info"].get("output_top_logprobs", None),
)
token_logprobs = self._process_logprobs_tokens(logprobs, use_token_index=True)
return ChoiceLogprobs(content=token_logprobs)
def _process_tool_calls(
self,
text: str,
tools: List[Any],
tool_call_parser: Optional[str],
finish_reason: Dict[str, Any],
) -> tuple[Optional[List[ToolCall]], str, Dict[str, Any]]:
"""Process tool calls in the response"""
parser = FunctionCallParser(tools, tool_call_parser)
if parser.has_tool_call(text):
if finish_reason["type"] == "stop":
finish_reason["type"] = "tool_calls"
finish_reason["matched"] = None
try:
text, call_info_list = parser.parse_non_stream(text)
tool_calls = [
ToolCall(
id=f"call_{base64.urlsafe_b64encode(uuid.uuid4().bytes).rstrip(b'=').decode()}",
function=FunctionResponse(
name=call_info.name, arguments=call_info.parameters
),
)
for call_info in call_info_list
]
return tool_calls, text, finish_reason
except Exception as e:
logger.error(f"Tool call parsing error: {e}")
                # Don't fail the whole request; return the text without tool calls
return None, text, finish_reason
return None, text, finish_reason
def _process_streaming_logprobs(
self, content: Dict[str, Any], n_prev_token: int
) -> ChoiceLogprobs:
"""Process logprobs for streaming response"""
logprobs = to_openai_style_logprobs(
output_token_logprobs=content["meta_info"]["output_token_logprobs"][
n_prev_token:
],
output_top_logprobs=content["meta_info"].get("output_top_logprobs", [])[
n_prev_token:
],
)
token_logprobs = self._process_logprobs_tokens(logprobs, use_token_index=False)
return ChoiceLogprobs(content=token_logprobs)
def _process_reasoning_stream(
self,
index: int,
delta: str,
reasoning_parser_dict: Dict[int, ReasoningParser],
content: Dict[str, Any],
request: ChatCompletionRequest,
) -> tuple[Optional[str], str]:
"""Process reasoning content in streaming response"""
if index not in reasoning_parser_dict:
reasoning_parser_dict[index] = ReasoningParser(
self.tokenizer_manager.server_args.reasoning_parser,
request.stream_reasoning,
)
reasoning_parser = reasoning_parser_dict[index]
return reasoning_parser.parse_stream_chunk(delta)
async def _process_tool_call_stream(
self,
index: int,
delta: str,
parser_dict: Dict[int, FunctionCallParser],
content: Dict[str, Any],
request: ChatCompletionRequest,
finish_reason_type: Optional[str],
):
"""Process tool calls in streaming response"""
if index not in parser_dict:
parser_dict[index] = FunctionCallParser(
tools=request.tools,
tool_call_parser=self.tokenizer_manager.server_args.tool_call_parser,
)
parser = parser_dict[index]
normal_text, calls = parser.parse_stream_chunk(delta)
# Yield normal text
if normal_text:
choice_data = ChatCompletionResponseStreamChoice(
index=index,
delta=DeltaMessage(content=normal_text),
finish_reason=finish_reason_type,
)
chunk = ChatCompletionStreamResponse(
id=content["meta_info"]["id"],
created=int(time.time()),
choices=[choice_data],
model=request.model,
)
yield f"data: {chunk.model_dump_json()}\n\n"
# Yield tool calls
for call_item in calls:
if finish_reason_type == "stop":
# Handle remaining arguments
latest_delta_len = 0
if isinstance(call_item.parameters, str):
latest_delta_len = len(call_item.parameters)
expected_call = json.dumps(
parser.detector.prev_tool_call_arr[index].get("arguments", {}),
ensure_ascii=False,
)
actual_call = parser.detector.streamed_args_for_tool[index]
if latest_delta_len > 0:
actual_call = actual_call[:-latest_delta_len]
remaining_call = expected_call.replace(actual_call, "", 1)
call_item.parameters = remaining_call
finish_reason_type = "tool_calls"
tool_call = ToolCall(
id=f"call_{base64.urlsafe_b64encode(uuid.uuid4().bytes).rstrip(b'=').decode()}",
index=call_item.tool_index,
function=FunctionResponse(
name=call_item.name,
arguments=call_item.parameters,
),
)
choice_data = ChatCompletionResponseStreamChoice(
index=index,
delta=DeltaMessage(tool_calls=[tool_call]),
finish_reason=(
None
if request.stream_options and request.stream_options.include_usage
else finish_reason_type
),
)
chunk = ChatCompletionStreamResponse(
id=content["meta_info"]["id"],
created=int(time.time()),
choices=[choice_data],
model=request.model,
)
yield f"data: {chunk.model_dump_json()}\n\n"
import time
from typing import Any, Dict, List, Optional, Union
from fastapi import Request
from fastapi.responses import StreamingResponse
from sglang.srt.code_completion_parser import (
generate_completion_prompt_from_request,
is_completion_template_defined,
)
from sglang.srt.entrypoints.openai.protocol import (
CompletionRequest,
CompletionResponse,
CompletionResponseChoice,
CompletionResponseStreamChoice,
CompletionStreamResponse,
ErrorResponse,
)
from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
from sglang.srt.entrypoints.openai.utils import (
aggregate_token_usage,
to_openai_style_logprobs,
)
from sglang.srt.managers.io_struct import GenerateReqInput
class OpenAIServingCompletion(OpenAIServingBase):
"""Handler for completion requests"""
def _request_id_prefix(self) -> str:
return "cmpl-"
def _validate_request(self, request: CompletionRequest) -> Optional[str]:
"""Validate completion prompt format and content"""
if not (prompt := request.prompt):
return "Prompt cannot be None"
if isinstance(prompt, str):
if not prompt.strip():
return "Prompt cannot be empty or whitespace only"
elif isinstance(prompt, list):
if not prompt:
return "Prompt list cannot be empty"
# Check if it's a list of strings
if all(isinstance(item, str) for item in prompt):
for i, item in enumerate(prompt):
if not item.strip():
return f"Prompt at index {i} cannot be empty or whitespace only"
# Check if it's a list of token IDs (integers)
elif all(isinstance(item, int) for item in prompt):
if any(item < 0 for item in prompt):
return "Token IDs must be non-negative"
# Check if it's a list of lists (multiple token sequences)
elif all(isinstance(item, list) for item in prompt):
for i, item in enumerate(prompt):
if not item:
return f"Token sequence at index {i} cannot be empty"
if not all(isinstance(token, int) for token in item):
return f"Token sequence at index {i} must contain only integers"
if any(token < 0 for token in item):
return (
f"Token sequence at index {i} contains negative token IDs"
)
else:
return "Prompt must be string, list of strings, list of integers, or list of integer lists"
else:
return "Prompt must be string or list"
return None
def _convert_to_internal_request(
self,
all_requests: List[CompletionRequest],
request_ids: List[str],
) -> tuple[GenerateReqInput, Union[CompletionRequest, List[CompletionRequest]]]:
"""Convert OpenAI completion request to internal format"""
# Validate batch requests
if len(all_requests) > 1:
first_prompt_type = type(all_requests[0].prompt)
for request in all_requests:
assert (
type(request.prompt) is first_prompt_type
), "All prompts must be of the same type in file input settings"
if request.n > 1:
raise ValueError(
"Parallel sampling is not supported for completions from files"
)
prompts = []
sampling_params_list = []
return_logprobs = []
logprob_start_lens = []
top_logprobs_nums = []
lora_paths = []
for request in all_requests:
# Process prompt
prompt = request.prompt
if is_completion_template_defined():
prompt = generate_completion_prompt_from_request(request)
prompts.append(prompt)
lora_paths.append(request.lora_path)
# Set logprob start length based on echo and logprobs
if request.echo and request.logprobs:
current_logprob_start_len = 0
else:
current_logprob_start_len = -1
# Build sampling parameters
sampling_params = self._build_sampling_params(request)
sampling_params_list.append(sampling_params)
return_logprobs.append(request.logprobs is not None)
logprob_start_lens.append(current_logprob_start_len)
top_logprobs_nums.append(
request.logprobs if request.logprobs is not None else 0
)
# Handle single vs multiple requests
if len(all_requests) == 1:
if isinstance(prompts[0], str) or isinstance(prompts[0][0], str):
prompt_kwargs = {"text": prompts[0]}
else:
prompt_kwargs = {"input_ids": prompts[0]}
sampling_params_list = sampling_params_list[0]
return_logprobs = return_logprobs[0]
logprob_start_lens = logprob_start_lens[0]
top_logprobs_nums = top_logprobs_nums[0]
lora_paths = lora_paths[0]
request_ids = request_ids[0]
else:
if isinstance(prompts[0], str) or isinstance(prompts[0][0], str):
prompt_kwargs = {"text": prompts}
else:
prompt_kwargs = {"input_ids": prompts}
adapted_request = GenerateReqInput(
**prompt_kwargs,
sampling_params=sampling_params_list,
return_logprob=return_logprobs,
top_logprobs_num=top_logprobs_nums,
logprob_start_len=logprob_start_lens,
return_text_in_logprobs=True,
stream=all_requests[0].stream,
rid=request_ids,
lora_path=lora_paths,
bootstrap_host=all_requests[0].bootstrap_host,
bootstrap_port=all_requests[0].bootstrap_port,
bootstrap_room=all_requests[0].bootstrap_room,
)
return adapted_request, (
all_requests if len(all_requests) > 1 else all_requests[0]
)
def _build_sampling_params(self, request: CompletionRequest) -> Dict[str, Any]:
"""Build sampling parameters for the request"""
# Start with common parameters
sampling_params = {
"temperature": request.temperature,
"max_new_tokens": request.max_tokens,
"min_new_tokens": request.min_tokens,
"stop": request.stop,
"stop_token_ids": request.stop_token_ids,
"top_p": request.top_p,
"top_k": request.top_k,
"min_p": request.min_p,
"presence_penalty": request.presence_penalty,
"frequency_penalty": request.frequency_penalty,
"repetition_penalty": request.repetition_penalty,
"regex": request.regex,
"json_schema": request.json_schema,
"ebnf": request.ebnf,
"n": request.n,
"no_stop_trim": request.no_stop_trim,
"ignore_eos": request.ignore_eos,
"skip_special_tokens": request.skip_special_tokens,
"logit_bias": request.logit_bias,
}
        # No additional completion-specific parameters are needed currently
        # (json_schema is passed straight through in the dict above).
return sampling_params
async def _handle_streaming_request(
self,
adapted_request: GenerateReqInput,
request: CompletionRequest,
raw_request: Request,
) -> StreamingResponse:
"""Handle streaming completion request"""
created = int(time.time())
async def generate_stream_resp():
stream_buffers = {}
n_prev_tokens = {}
prompt_tokens = {}
completion_tokens = {}
cached_tokens = {}
try:
async for content in self.tokenizer_manager.generate_request(
adapted_request, raw_request
):
index = content.get("index", 0)
stream_buffer = stream_buffers.get(index, "")
n_prev_token = n_prev_tokens.get(index, 0)
text = content["text"]
prompt_tokens[index] = content["meta_info"]["prompt_tokens"]
completion_tokens[index] = content["meta_info"]["completion_tokens"]
cached_tokens[index] = content["meta_info"].get("cached_tokens", 0)
# Handle echo for first chunk
if not stream_buffer: # The first chunk
if request.echo:
echo_text = self._get_echo_text(request, index)
text = echo_text + text
# Handle logprobs
logprobs = None
if request.logprobs is not None:
                        # First chunk with echo enabled: include input logprobs
if not stream_buffer and request.echo:
input_token_logprobs = content["meta_info"][
"input_token_logprobs"
]
input_top_logprobs = content["meta_info"][
"input_top_logprobs"
]
else:
input_token_logprobs = None
input_top_logprobs = None
logprobs = to_openai_style_logprobs(
input_token_logprobs=input_token_logprobs,
input_top_logprobs=input_top_logprobs,
output_token_logprobs=content["meta_info"][
"output_token_logprobs"
][n_prev_token:],
output_top_logprobs=content["meta_info"][
"output_top_logprobs"
][n_prev_token:],
)
n_prev_token = len(
content["meta_info"]["output_token_logprobs"]
)
# Generate delta
delta = text[len(stream_buffer) :]
stream_buffer = stream_buffer + delta
finish_reason = content["meta_info"]["finish_reason"]
choice_data = CompletionResponseStreamChoice(
index=index,
text=delta,
logprobs=logprobs,
finish_reason=finish_reason["type"] if finish_reason else None,
matched_stop=(
finish_reason["matched"]
if finish_reason and "matched" in finish_reason
else None
),
)
chunk = CompletionStreamResponse(
id=content["meta_info"]["id"],
created=created,
object="text_completion",
choices=[choice_data],
model=request.model,
)
stream_buffers[index] = stream_buffer
n_prev_tokens[index] = n_prev_token
yield f"data: {chunk.model_dump_json()}\n\n"
# Handle final usage chunk
if request.stream_options and request.stream_options.include_usage:
usage = self._calculate_streaming_usage_base(
prompt_tokens, completion_tokens, cached_tokens, request.n
)
final_usage_chunk = CompletionStreamResponse(
id=content["meta_info"]["id"],
created=created,
choices=[],
model=request.model,
usage=usage,
)
final_usage_data = final_usage_chunk.model_dump_json(
exclude_none=True
)
yield f"data: {final_usage_data}\n\n"
except Exception as e:
error = self.create_streaming_error_response(str(e))
yield f"data: {error}\n\n"
yield "data: [DONE]\n\n"
return StreamingResponse(
generate_stream_resp(),
media_type="text/event-stream",
background=self.tokenizer_manager.create_abort_task(adapted_request),
)
async def _handle_non_streaming_request(
self,
adapted_request: GenerateReqInput,
request: CompletionRequest,
raw_request: Request,
) -> Union[CompletionResponse, ErrorResponse]:
"""Handle non-streaming completion request"""
try:
generator = self.tokenizer_manager.generate_request(
adapted_request, raw_request
)
ret = await generator.__anext__()
except ValueError as e:
return self.create_error_response(str(e))
if not isinstance(ret, list):
ret = [ret]
response = self._build_completion_response(
request,
ret,
int(time.time()),
cache_report=self.tokenizer_manager.server_args.enable_cache_report,
)
return response
def _build_completion_response(
self,
request: CompletionRequest,
ret: List[Dict[str, Any]],
created: int,
cache_report: bool = False,
) -> CompletionResponse:
"""Build completion response from generation results"""
choices = []
echo = False
# Prepare echo prompts if needed
echo_prompts = []
if (not isinstance(request, list)) and request.echo:
echo_prompts = self._prepare_echo_prompts(request)
echo = True
for idx, ret_item in enumerate(ret):
text = ret_item["text"]
# Handle echo
if isinstance(request, list) and request[idx].echo:
echo = True
text = request[idx].prompt + text
elif echo and not isinstance(request, list):
prompt_index = idx // request.n
text = echo_prompts[prompt_index] + text
# Handle logprobs
logprobs = None
if isinstance(request, list) and request[idx].logprobs is not None:
logprobs = True
elif (not isinstance(request, list)) and request.logprobs is not None:
logprobs = True
if logprobs:
if echo:
input_token_logprobs = ret_item["meta_info"]["input_token_logprobs"]
input_top_logprobs = ret_item["meta_info"]["input_top_logprobs"]
else:
input_token_logprobs = None
input_top_logprobs = None
logprobs = to_openai_style_logprobs(
input_token_logprobs=input_token_logprobs,
input_top_logprobs=input_top_logprobs,
output_token_logprobs=ret_item["meta_info"][
"output_token_logprobs"
],
output_top_logprobs=ret_item["meta_info"]["output_top_logprobs"],
)
finish_reason = ret_item["meta_info"]["finish_reason"]
choice_data = CompletionResponseChoice(
index=idx,
text=text,
logprobs=logprobs,
finish_reason=finish_reason["type"] if finish_reason else None,
matched_stop=(
finish_reason["matched"]
if finish_reason and "matched" in finish_reason
else None
),
)
choices.append(choice_data)
# Calculate usage
usage = aggregate_token_usage(ret, request.n, cache_report)
return CompletionResponse(
id=ret[0]["meta_info"]["id"],
model=request.model,
created=created,
choices=choices,
usage=usage,
)
def _get_echo_text(self, request: CompletionRequest, index: int) -> str:
"""Get echo text for streaming response"""
if isinstance(request.prompt, str):
# for the case of single str prompts
return request.prompt
elif isinstance(request.prompt, list):
if isinstance(request.prompt[0], str):
# for the case of multiple str prompts
return request.prompt[index // request.n]
elif isinstance(request.prompt[0], int):
# for the case of single token ids prompt
return self.tokenizer_manager.tokenizer.decode(
request.prompt, skip_special_tokens=True
)
elif isinstance(request.prompt[0], list) and isinstance(
request.prompt[0][0], int
):
# for the case of multiple token ids prompts
return self.tokenizer_manager.tokenizer.decode(
request.prompt[index // request.n],
skip_special_tokens=True,
)
return ""
def _prepare_echo_prompts(self, request: CompletionRequest) -> List[str]:
"""Prepare echo prompts for non-streaming response"""
# TODO: handle the case prompt is token ids
if isinstance(request.prompt, list) and isinstance(request.prompt[0], str):
# for the case of multiple str prompts
return request.prompt
elif isinstance(request.prompt, list) and isinstance(request.prompt[0], list):
# for the case of multiple token ids prompts
return [
self.tokenizer_manager.tokenizer.decode(
prompt, skip_special_tokens=True
)
for prompt in request.prompt
]
elif isinstance(request.prompt, list) and isinstance(request.prompt[0], int):
# for the case of single token ids prompt
return [
self.tokenizer_manager.tokenizer.decode(
request.prompt, skip_special_tokens=True
)
]
else:
# for the case of single str prompt
return [request.prompt]
from typing import Any, Dict, List, Optional, Union
from fastapi import Request
from sglang.srt.conversation import generate_embedding_convs
from sglang.srt.entrypoints.openai.protocol import (
EmbeddingObject,
EmbeddingRequest,
EmbeddingResponse,
ErrorResponse,
MultimodalEmbeddingInput,
UsageInfo,
)
from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
from sglang.srt.managers.io_struct import EmbeddingReqInput
class OpenAIServingEmbedding(OpenAIServingBase):
"""Handler for embedding requests"""
def _request_id_prefix(self) -> str:
return "embd-"
def _validate_request(self, request: EmbeddingRequest) -> Optional[str]:
"""Validate that the input is not empty or whitespace only."""
if not (input := request.input):
return "Input cannot be empty"
# Handle single string
if isinstance(input, str):
if not input.strip():
return "Input cannot be empty or whitespace only"
return None
# Handle list inputs
if isinstance(input, list):
if len(input) == 0:
return "Input cannot be empty"
# Check first element to determine type
first_item = input[0]
if isinstance(first_item, str):
# List of strings
for i, item in enumerate(input):
if not isinstance(item, str):
return f"All items in input list must be strings"
if not item.strip():
return f"Input at index {i} cannot be empty or whitespace only"
elif isinstance(first_item, int):
# List of integers (token IDs)
for i, item in enumerate(input):
if not isinstance(item, int):
return f"All items in input list must be integers"
if item < 0:
return f"Token ID at index {i} must be non-negative"
elif isinstance(first_item, list):
# List of lists (multiple token sequences)
for i, item in enumerate(input):
if not isinstance(item, list):
return f"Input at index {i} must be a list"
if not item:
return f"Input at index {i} cannot be empty"
if not all(isinstance(token, int) for token in item):
return f"Input at index {i} must contain only integers"
if any(token < 0 for token in item):
return f"Input at index {i} contains negative token IDs"
# Note: MultimodalEmbeddingInput validation would be handled by Pydantic
return None
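    # Illustrative examples (not executed) of what _validate_request accepts and
    # rejects, assuming an EmbeddingRequest with model="m":
    #   input="hello"        -> None (valid)
    #   input="   "          -> "Input cannot be empty or whitespace only"
    #   input=["hi", ""]     -> "Input at index 1 cannot be empty or whitespace only"
    #   input=[1, 2, 3]      -> None (valid token IDs)
    #   input=[[1, 2], []]   -> "Input at index 1 cannot be empty"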
def _convert_to_internal_request(
self,
all_requests: List[EmbeddingRequest],
request_ids: List[str],
) -> tuple[EmbeddingReqInput, Union[EmbeddingRequest, List[EmbeddingRequest]]]:
"""Convert OpenAI embedding request to internal format"""
prompts = [request.input for request in all_requests]
# Handle single vs multiple requests
if len(all_requests) == 1:
prompt = prompts[0]
if isinstance(prompt, str):
# Single string input
prompt_kwargs = {"text": prompt}
elif isinstance(prompt, list):
if len(prompt) > 0 and isinstance(prompt[0], str):
# List of strings
prompt_kwargs = {"text": prompt}
elif len(prompt) > 0 and isinstance(
prompt[0], MultimodalEmbeddingInput
):
# Handle multimodal embedding inputs
texts = []
images = []
for item in prompt:
# Use padding for text if None - this could be improved
texts.append(item.text if item.text is not None else "padding")
                        images.append(item.image)
generate_prompts = []
# Check if we have a chat template for multimodal embeddings
# This would need to be passed in from the server configuration
chat_template_name = getattr(
self.tokenizer_manager, "chat_template_name", None
)
if chat_template_name is not None:
convs = generate_embedding_convs(
texts, images, chat_template_name
)
for conv in convs:
generate_prompts.append(conv.get_prompt())
else:
generate_prompts = texts
if len(generate_prompts) == 1:
prompt_kwargs = {
"text": generate_prompts[0],
"image_data": images[0],
}
else:
prompt_kwargs = {
"text": generate_prompts,
"image_data": images,
}
else:
# List of integers (token IDs) or empty list
prompt_kwargs = {"input_ids": prompt}
else:
# Other types (should not happen but handle gracefully)
prompt_kwargs = {"input_ids": prompt}
# Use the passed request_ids for single request
            final_request_id = request_ids[0]
else:
# Handle batch requests
if len(prompts) > 0:
# Validate that all prompts have the same type
first_prompt = prompts[0]
first_type = type(first_prompt)
for i, prompt in enumerate(prompts[1:], 1):
if type(prompt) != first_type:
raise AssertionError(
f"All prompts in batch must have the same type, but prompt at index {i} has different type"
)
if isinstance(first_prompt, str):
# Batch of strings
prompt_kwargs = {"text": prompts}
elif isinstance(first_prompt, list):
if len(first_prompt) > 0 and isinstance(first_prompt[0], str):
# Batch of lists of strings
prompt_kwargs = {"text": prompts}
elif len(first_prompt) > 0 and isinstance(
first_prompt[0], MultimodalEmbeddingInput
):
# Handle multimodal batch requests
raise NotImplementedError(
"Multiple requests with multimodal inputs are not supported yet"
)
else:
# Batch of token ID lists
prompt_kwargs = {"input_ids": prompts}
else:
# Other types
prompt_kwargs = {"input_ids": prompts}
else:
prompt_kwargs = {"input_ids": prompts}
# Use the passed request_ids for batch requests
final_request_id = request_ids
adapted_request = EmbeddingReqInput(
rid=final_request_id,
**prompt_kwargs,
)
return adapted_request, (
all_requests[0] if len(all_requests) == 1 else all_requests
)
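    # Illustrative sketch (not executed): for a single multimodal request such as
    #   EmbeddingRequest(model="m", input=[MultimodalEmbeddingInput(text="hi", image=img)])
    # the adapted request is built roughly as
    #   EmbeddingReqInput(rid=request_ids[0], text="hi", image_data=img)
    # (the text may instead be a chat-template prompt when chat_template_name is set),
    # whereas a batch of plain string inputs becomes
    #   EmbeddingReqInput(rid=request_ids, text=["hi", "there"])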
async def _handle_non_streaming_request(
self,
adapted_request: EmbeddingReqInput,
request: EmbeddingRequest,
raw_request: Request,
) -> Union[EmbeddingResponse, ErrorResponse]:
"""Handle the embedding request"""
try:
ret = await self.tokenizer_manager.generate_request(
adapted_request, raw_request
).__anext__()
except ValueError as e:
return self.create_error_response(str(e))
if not isinstance(ret, list):
ret = [ret]
response = self._build_embedding_response(
ret, self.tokenizer_manager.model_path
)
return response
def _build_embedding_response(
self, ret: List[Dict[str, Any]], model_path: str
) -> EmbeddingResponse:
"""Build the embedding response"""
embedding_objects = []
prompt_tokens = 0
for idx, ret_item in enumerate(ret):
embedding_objects.append(
EmbeddingObject(
embedding=ret_item["embedding"],
index=idx,
)
)
# Handle missing prompt_tokens gracefully
meta_info = ret_item.get("meta_info", {})
prompt_tokens += meta_info.get("prompt_tokens", 0)
return EmbeddingResponse(
data=embedding_objects,
model=model_path,
usage=UsageInfo(
prompt_tokens=prompt_tokens,
total_tokens=prompt_tokens,
),
)
import logging
from typing import Any, Dict, List, Optional
import jinja2.nodes
import transformers.utils.chat_template_utils as hf_chat_utils
from sglang.srt.entrypoints.openai.protocol import LogProbs, UsageInfo
logger = logging.getLogger(__name__)
# ============================================================================
# JINJA TEMPLATE CONTENT FORMAT DETECTION
# ============================================================================
#
# This adapts vLLM's approach for detecting chat template content format:
# https://github.com/vllm-project/vllm/blob/02f0c7b220422792f5e53de2a7d51d2d3ff2df28/vllm/entrypoints/chat_utils.py#L296-L313
# - Analyzes Jinja template AST to detect content iteration patterns
# - 'openai' format: templates with {%- for content in message['content'] -%} loops
# - 'string' format: templates that expect simple string content
# - Processes content accordingly to match template expectations
def _is_var_access(node: jinja2.nodes.Node, varname: str) -> bool:
"""Check if node is a variable access like {{ varname }}"""
if isinstance(node, jinja2.nodes.Name):
return node.ctx == "load" and node.name == varname
return False
def _is_attr_access(node: jinja2.nodes.Node, varname: str, key: str) -> bool:
"""Check if node is an attribute access like {{ varname['key'] }} or {{ varname.key }}"""
if isinstance(node, jinja2.nodes.Getitem):
return (
_is_var_access(node.node, varname)
and isinstance(node.arg, jinja2.nodes.Const)
and node.arg.value == key
)
if isinstance(node, jinja2.nodes.Getattr):
return _is_var_access(node.node, varname) and node.attr == key
return False
def _is_var_or_elems_access(
node: jinja2.nodes.Node,
varname: str,
    key: Optional[str] = None,
) -> bool:
"""Check if node accesses varname or varname[key] with filters/tests"""
if isinstance(node, jinja2.nodes.Filter):
return node.node is not None and _is_var_or_elems_access(
node.node, varname, key
)
if isinstance(node, jinja2.nodes.Test):
return _is_var_or_elems_access(node.node, varname, key)
if isinstance(node, jinja2.nodes.Getitem) and isinstance(
node.arg, jinja2.nodes.Slice
):
return _is_var_or_elems_access(node.node, varname, key)
return _is_attr_access(node, varname, key) if key else _is_var_access(node, varname)
def _try_extract_ast(chat_template: str):
"""Try to parse the Jinja template into an AST"""
try:
jinja_compiled = hf_chat_utils._compile_jinja_template(chat_template)
return jinja_compiled.environment.parse(chat_template)
except Exception as e:
logger.debug(f"Error when compiling Jinja template: {e}")
return None
def detect_template_content_format(chat_template: str) -> str:
"""
Detect whether a chat template expects 'string' or 'openai' content format.
- 'string': content is a simple string (like DeepSeek templates)
- 'openai': content is a list of structured dicts (like Llama4 templates)
Detection logic:
- If template has loops like {%- for content in message['content'] -%} → 'openai'
- Otherwise → 'string'
"""
jinja_ast = _try_extract_ast(chat_template)
if jinja_ast is None:
return "string"
try:
# Look for patterns like: {%- for content in message['content'] -%}
for loop_ast in jinja_ast.find_all(jinja2.nodes.For):
loop_iter = loop_ast.iter
# Check if iterating over message['content'] or similar
if _is_var_or_elems_access(loop_iter, "message", "content"):
return "openai" # Found content iteration → openai format
return "string" # No content loops found → string format
except Exception as e:
logger.debug(f"Error when parsing AST of Jinja template: {e}")
return "string"
def process_content_for_template_format(
msg_dict: dict,
content_format: str,
image_data: list,
audio_data: list,
modalities: list,
) -> dict:
"""
Process message content based on detected template format.
Args:
msg_dict: Message dictionary with content
content_format: 'string' or 'openai' (detected via AST analysis)
image_data: List to append extracted image URLs
audio_data: List to append extracted audio URLs
modalities: List to append modalities
Returns:
Processed message dictionary
"""
if not isinstance(msg_dict.get("content"), list):
# Already a string or None, no processing needed
return {k: v for k, v in msg_dict.items() if v is not None}
if content_format == "openai":
# OpenAI format: preserve structured content list, normalize types
processed_content_parts = []
for chunk in msg_dict["content"]:
if isinstance(chunk, dict):
chunk_type = chunk.get("type")
if chunk_type == "image_url":
image_data.append(chunk["image_url"]["url"])
if chunk.get("modalities"):
modalities.append(chunk.get("modalities"))
# Normalize to simple 'image' type for template compatibility
processed_content_parts.append({"type": "image"})
elif chunk_type == "audio_url":
audio_data.append(chunk["audio_url"]["url"])
# Normalize to simple 'audio' type
processed_content_parts.append({"type": "audio"})
else:
# Keep other content as-is (text, etc.)
processed_content_parts.append(chunk)
new_msg = {
k: v for k, v in msg_dict.items() if v is not None and k != "content"
}
new_msg["content"] = processed_content_parts
return new_msg
else: # content_format == "string"
# String format: flatten to text only (for templates like DeepSeek)
text_parts = []
for chunk in msg_dict["content"]:
if isinstance(chunk, dict) and chunk.get("type") == "text":
text_parts.append(chunk["text"])
# Note: For string format, we ignore images/audio since the template
# doesn't expect structured content - multimodal placeholders would
# need to be inserted differently
new_msg = msg_dict.copy()
new_msg["content"] = " ".join(text_parts) if text_parts else ""
new_msg = {k: v for k, v in new_msg.items() if v is not None}
return new_msg
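# Illustrative sketch (not executed): given a message such as
#   msg = {"role": "user", "content": [
#       {"type": "text", "text": "What is this?"},
#       {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
#   ]}
# with content_format="openai" the result keeps a structured content list,
#   {"role": "user", "content": [{"type": "text", ...}, {"type": "image"}]}
# and the URL is appended to image_data; with content_format="string" the result is
# flattened to {"role": "user", "content": "What is this?"} and the image is ignored.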
def calculate_token_usage(
prompt_tokens: int,
completion_tokens: int,
cached_tokens: Optional[Dict[str, int]] = None,
) -> UsageInfo:
"""Calculate token usage information"""
return UsageInfo(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
prompt_tokens_details=cached_tokens,
)
def aggregate_token_usage(
responses: List[Dict[str, Any]],
n_choices: int = 1,
enable_cache_report: bool = False,
) -> UsageInfo:
"""Aggregate token usage from multiple responses
Args:
responses: List of response dictionaries with meta_info
n_choices: Number of choices per request (for prompt token counting)
enable_cache_report: Whether to include cached token details
Returns:
Aggregated UsageInfo
"""
# Sum completion tokens from all responses
completion_tokens = sum(
response["meta_info"]["completion_tokens"] for response in responses
)
# For prompt tokens, only count every n_choices-th response to avoid double counting
prompt_tokens = sum(
responses[i]["meta_info"]["prompt_tokens"]
for i in range(0, len(responses), n_choices)
)
# Handle cached tokens if cache reporting is enabled
cached_tokens_details = None
if enable_cache_report:
cached_tokens_sum = sum(
response["meta_info"].get("cached_tokens", 0) for response in responses
)
if cached_tokens_sum > 0:
cached_tokens_details = {"cached_tokens": cached_tokens_sum}
return calculate_token_usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
cached_tokens=cached_tokens_details,
)
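# Illustrative worked example (not executed): with n_choices=2 and two generations
# that share one prompt, e.g.
#   responses = [
#       {"meta_info": {"prompt_tokens": 10, "completion_tokens": 3}},
#       {"meta_info": {"prompt_tokens": 10, "completion_tokens": 4}},
#   ]
# aggregate_token_usage(responses, n_choices=2) counts the prompt only once:
#   prompt_tokens=10, completion_tokens=7, total_tokens=17.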
def to_openai_style_logprobs(
input_token_logprobs=None,
output_token_logprobs=None,
input_top_logprobs=None,
output_top_logprobs=None,
):
ret_logprobs = LogProbs()
def append_token_logprobs(token_logprobs):
for logprob, _, token_text in token_logprobs:
ret_logprobs.tokens.append(token_text)
ret_logprobs.token_logprobs.append(logprob)
# Not supported yet
ret_logprobs.text_offset.append(-1)
def append_top_logprobs(top_logprobs):
for tokens in top_logprobs:
if tokens is not None:
ret_logprobs.top_logprobs.append(
{token[2]: token[0] for token in tokens}
)
else:
ret_logprobs.top_logprobs.append(None)
if input_token_logprobs is not None:
append_token_logprobs(input_token_logprobs)
if output_token_logprobs is not None:
append_token_logprobs(output_token_logprobs)
if input_top_logprobs is not None:
append_top_logprobs(input_top_logprobs)
if output_top_logprobs is not None:
append_top_logprobs(output_top_logprobs)
return ret_logprobs
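# Illustrative sketch (not executed): each entry in *_token_logprobs is assumed to be
# a (logprob, token_id, token_text) tuple, and each top-logprob entry a list of such
# tuples, e.g.
#   to_openai_style_logprobs(
#       output_token_logprobs=[(-0.1, 42, "Hello"), (-0.7, 7, "!")],
#       output_top_logprobs=[[(-0.1, 42, "Hello"), (-2.3, 13, "Hi")], None],
#   )
# returns LogProbs(tokens=["Hello", "!"], token_logprobs=[-0.1, -0.7],
#                  top_logprobs=[{"Hello": -0.1, "Hi": -2.3}, None], text_offset=[-1, -1])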
[pytest]
asyncio_mode = auto
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for OpenAI API protocol models"""
import json
import time
from typing import Dict, List, Optional
import pytest
from pydantic import ValidationError
from sglang.srt.entrypoints.openai.protocol import (
BatchRequest,
BatchResponse,
ChatCompletionMessageContentImagePart,
ChatCompletionMessageContentTextPart,
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseChoice,
ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse,
ChatCompletionTokenLogprob,
ChatMessage,
ChoiceLogprobs,
CompletionRequest,
CompletionResponse,
CompletionResponseChoice,
DeltaMessage,
EmbeddingObject,
EmbeddingRequest,
EmbeddingResponse,
ErrorResponse,
FileDeleteResponse,
FileRequest,
FileResponse,
Function,
FunctionResponse,
JsonSchemaResponseFormat,
LogProbs,
ModelCard,
ModelList,
MultimodalEmbeddingInput,
ResponseFormat,
ScoringRequest,
ScoringResponse,
StreamOptions,
StructuralTagResponseFormat,
Tool,
ToolCall,
ToolChoice,
TopLogprob,
UsageInfo,
)
class TestModelCard:
"""Test ModelCard protocol model"""
def test_basic_model_card_creation(self):
"""Test basic model card creation with required fields"""
card = ModelCard(id="test-model")
assert card.id == "test-model"
assert card.object == "model"
assert card.owned_by == "sglang"
assert isinstance(card.created, int)
assert card.root is None
assert card.max_model_len is None
def test_model_card_with_optional_fields(self):
"""Test model card with optional fields"""
card = ModelCard(
id="test-model",
root="/path/to/model",
max_model_len=2048,
created=1234567890,
)
assert card.id == "test-model"
assert card.root == "/path/to/model"
assert card.max_model_len == 2048
assert card.created == 1234567890
def test_model_card_serialization(self):
"""Test model card JSON serialization"""
card = ModelCard(id="test-model", max_model_len=4096)
data = card.model_dump()
assert data["id"] == "test-model"
assert data["object"] == "model"
assert data["max_model_len"] == 4096
class TestModelList:
"""Test ModelList protocol model"""
def test_empty_model_list(self):
"""Test empty model list creation"""
model_list = ModelList()
assert model_list.object == "list"
assert len(model_list.data) == 0
def test_model_list_with_cards(self):
"""Test model list with model cards"""
cards = [
ModelCard(id="model-1"),
ModelCard(id="model-2", max_model_len=2048),
]
model_list = ModelList(data=cards)
assert len(model_list.data) == 2
assert model_list.data[0].id == "model-1"
assert model_list.data[1].id == "model-2"
class TestErrorResponse:
"""Test ErrorResponse protocol model"""
def test_basic_error_response(self):
"""Test basic error response creation"""
error = ErrorResponse(
message="Invalid request", type="BadRequestError", code=400
)
assert error.object == "error"
assert error.message == "Invalid request"
assert error.type == "BadRequestError"
assert error.code == 400
assert error.param is None
def test_error_response_with_param(self):
"""Test error response with parameter"""
error = ErrorResponse(
message="Invalid temperature",
type="ValidationError",
code=422,
param="temperature",
)
assert error.param == "temperature"
class TestUsageInfo:
"""Test UsageInfo protocol model"""
def test_basic_usage_info(self):
"""Test basic usage info creation"""
usage = UsageInfo(prompt_tokens=10, completion_tokens=20, total_tokens=30)
assert usage.prompt_tokens == 10
assert usage.completion_tokens == 20
assert usage.total_tokens == 30
assert usage.prompt_tokens_details is None
def test_usage_info_with_cache_details(self):
"""Test usage info with cache details"""
usage = UsageInfo(
prompt_tokens=10,
completion_tokens=20,
total_tokens=30,
prompt_tokens_details={"cached_tokens": 5},
)
assert usage.prompt_tokens_details == {"cached_tokens": 5}
class TestCompletionRequest:
"""Test CompletionRequest protocol model"""
def test_basic_completion_request(self):
"""Test basic completion request"""
request = CompletionRequest(model="test-model", prompt="Hello world")
assert request.model == "test-model"
assert request.prompt == "Hello world"
assert request.max_tokens == 16 # default
assert request.temperature == 1.0 # default
assert request.n == 1 # default
assert not request.stream # default
assert not request.echo # default
def test_completion_request_with_options(self):
"""Test completion request with various options"""
request = CompletionRequest(
model="test-model",
prompt=["Hello", "world"],
max_tokens=100,
temperature=0.7,
top_p=0.9,
n=2,
stream=True,
echo=True,
stop=[".", "!"],
logprobs=5,
)
assert request.prompt == ["Hello", "world"]
assert request.max_tokens == 100
assert request.temperature == 0.7
assert request.top_p == 0.9
assert request.n == 2
assert request.stream
assert request.echo
assert request.stop == [".", "!"]
assert request.logprobs == 5
def test_completion_request_sglang_extensions(self):
"""Test completion request with SGLang-specific extensions"""
request = CompletionRequest(
model="test-model",
prompt="Hello",
top_k=50,
min_p=0.1,
repetition_penalty=1.1,
regex=r"\d+",
json_schema='{"type": "object"}',
lora_path="/path/to/lora",
)
assert request.top_k == 50
assert request.min_p == 0.1
assert request.repetition_penalty == 1.1
assert request.regex == r"\d+"
assert request.json_schema == '{"type": "object"}'
assert request.lora_path == "/path/to/lora"
def test_completion_request_validation_errors(self):
"""Test completion request validation errors"""
with pytest.raises(ValidationError):
CompletionRequest() # missing required fields
with pytest.raises(ValidationError):
CompletionRequest(model="test-model") # missing prompt
class TestCompletionResponse:
"""Test CompletionResponse protocol model"""
def test_basic_completion_response(self):
"""Test basic completion response"""
choice = CompletionResponseChoice(
index=0, text="Hello world!", finish_reason="stop"
)
usage = UsageInfo(prompt_tokens=2, completion_tokens=3, total_tokens=5)
response = CompletionResponse(
id="test-id", model="test-model", choices=[choice], usage=usage
)
assert response.id == "test-id"
assert response.object == "text_completion"
assert response.model == "test-model"
assert len(response.choices) == 1
assert response.choices[0].text == "Hello world!"
assert response.usage.total_tokens == 5
class TestChatCompletionRequest:
"""Test ChatCompletionRequest protocol model"""
def test_basic_chat_completion_request(self):
"""Test basic chat completion request"""
messages = [{"role": "user", "content": "Hello"}]
request = ChatCompletionRequest(model="test-model", messages=messages)
assert request.model == "test-model"
assert len(request.messages) == 1
assert request.messages[0].role == "user"
assert request.messages[0].content == "Hello"
assert request.temperature == 0.7 # default
assert not request.stream # default
assert request.tool_choice == "none" # default when no tools
def test_chat_completion_with_multimodal_content(self):
"""Test chat completion with multimodal content"""
messages = [
{
"role": "user",
"content": [
{"type": "text", "text": "What's in this image?"},
{
"type": "image_url",
"image_url": {"url": "data:image/jpeg;base64,/9j/4AAQ..."},
},
],
}
]
request = ChatCompletionRequest(model="test-model", messages=messages)
assert len(request.messages[0].content) == 2
assert request.messages[0].content[0].type == "text"
assert request.messages[0].content[1].type == "image_url"
def test_chat_completion_with_tools(self):
"""Test chat completion with tools"""
messages = [{"role": "user", "content": "What's the weather?"}]
tools = [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get weather information",
"parameters": {
"type": "object",
"properties": {"location": {"type": "string"}},
},
},
}
]
request = ChatCompletionRequest(
model="test-model", messages=messages, tools=tools
)
assert len(request.tools) == 1
assert request.tools[0].function.name == "get_weather"
assert request.tool_choice == "auto" # default when tools present
def test_chat_completion_tool_choice_validation(self):
"""Test tool choice validation logic"""
messages = [{"role": "user", "content": "Hello"}]
# No tools, tool_choice should default to "none"
request1 = ChatCompletionRequest(model="test-model", messages=messages)
assert request1.tool_choice == "none"
# With tools, tool_choice should default to "auto"
tools = [
{
"type": "function",
"function": {"name": "test_func", "description": "Test function"},
}
]
request2 = ChatCompletionRequest(
model="test-model", messages=messages, tools=tools
)
assert request2.tool_choice == "auto"
def test_chat_completion_sglang_extensions(self):
"""Test chat completion with SGLang extensions"""
messages = [{"role": "user", "content": "Hello"}]
request = ChatCompletionRequest(
model="test-model",
messages=messages,
top_k=40,
min_p=0.05,
separate_reasoning=False,
stream_reasoning=False,
chat_template_kwargs={"custom_param": "value"},
)
assert request.top_k == 40
assert request.min_p == 0.05
assert not request.separate_reasoning
assert not request.stream_reasoning
assert request.chat_template_kwargs == {"custom_param": "value"}
class TestChatCompletionResponse:
"""Test ChatCompletionResponse protocol model"""
def test_basic_chat_completion_response(self):
"""Test basic chat completion response"""
message = ChatMessage(role="assistant", content="Hello there!")
choice = ChatCompletionResponseChoice(
index=0, message=message, finish_reason="stop"
)
usage = UsageInfo(prompt_tokens=2, completion_tokens=3, total_tokens=5)
response = ChatCompletionResponse(
id="test-id", model="test-model", choices=[choice], usage=usage
)
assert response.id == "test-id"
assert response.object == "chat.completion"
assert response.model == "test-model"
assert len(response.choices) == 1
assert response.choices[0].message.content == "Hello there!"
def test_chat_completion_response_with_tool_calls(self):
"""Test chat completion response with tool calls"""
tool_call = ToolCall(
id="call_123",
function=FunctionResponse(
name="get_weather", arguments='{"location": "San Francisco"}'
),
)
message = ChatMessage(role="assistant", content=None, tool_calls=[tool_call])
choice = ChatCompletionResponseChoice(
index=0, message=message, finish_reason="tool_calls"
)
usage = UsageInfo(prompt_tokens=10, completion_tokens=5, total_tokens=15)
response = ChatCompletionResponse(
id="test-id", model="test-model", choices=[choice], usage=usage
)
assert response.choices[0].message.tool_calls[0].function.name == "get_weather"
assert response.choices[0].finish_reason == "tool_calls"
class TestEmbeddingRequest:
"""Test EmbeddingRequest protocol model"""
def test_basic_embedding_request(self):
"""Test basic embedding request"""
request = EmbeddingRequest(model="test-model", input="Hello world")
assert request.model == "test-model"
assert request.input == "Hello world"
assert request.encoding_format == "float" # default
assert request.dimensions is None # default
def test_embedding_request_with_list_input(self):
"""Test embedding request with list input"""
request = EmbeddingRequest(
model="test-model", input=["Hello", "world"], dimensions=512
)
assert request.input == ["Hello", "world"]
assert request.dimensions == 512
def test_multimodal_embedding_request(self):
"""Test multimodal embedding request"""
multimodal_input = [
MultimodalEmbeddingInput(text="Hello", image="base64_image_data"),
MultimodalEmbeddingInput(text="World", image=None),
]
request = EmbeddingRequest(model="test-model", input=multimodal_input)
assert len(request.input) == 2
assert request.input[0].text == "Hello"
assert request.input[0].image == "base64_image_data"
assert request.input[1].text == "World"
assert request.input[1].image is None
class TestEmbeddingResponse:
"""Test EmbeddingResponse protocol model"""
def test_basic_embedding_response(self):
"""Test basic embedding response"""
embedding_obj = EmbeddingObject(embedding=[0.1, 0.2, 0.3], index=0)
usage = UsageInfo(prompt_tokens=3, total_tokens=3)
response = EmbeddingResponse(
data=[embedding_obj], model="test-model", usage=usage
)
assert response.object == "list"
assert len(response.data) == 1
assert response.data[0].embedding == [0.1, 0.2, 0.3]
assert response.data[0].index == 0
assert response.usage.prompt_tokens == 3
class TestScoringRequest:
"""Test ScoringRequest protocol model"""
def test_basic_scoring_request(self):
"""Test basic scoring request"""
request = ScoringRequest(
model="test-model", query="Hello", items=["World", "Earth"]
)
assert request.model == "test-model"
assert request.query == "Hello"
assert request.items == ["World", "Earth"]
assert not request.apply_softmax # default
assert not request.item_first # default
def test_scoring_request_with_token_ids(self):
"""Test scoring request with token IDs"""
request = ScoringRequest(
model="test-model",
query=[1, 2, 3],
items=[[4, 5], [6, 7]],
label_token_ids=[8, 9],
apply_softmax=True,
item_first=True,
)
assert request.query == [1, 2, 3]
assert request.items == [[4, 5], [6, 7]]
assert request.label_token_ids == [8, 9]
assert request.apply_softmax
assert request.item_first
class TestScoringResponse:
"""Test ScoringResponse protocol model"""
def test_basic_scoring_response(self):
"""Test basic scoring response"""
response = ScoringResponse(scores=[[0.1, 0.9], [0.3, 0.7]], model="test-model")
assert response.object == "scoring"
assert response.scores == [[0.1, 0.9], [0.3, 0.7]]
assert response.model == "test-model"
assert response.usage is None # default
class TestFileOperations:
"""Test file operation protocol models"""
def test_file_request(self):
"""Test file request model"""
file_data = b"test file content"
request = FileRequest(file=file_data, purpose="batch")
assert request.file == file_data
assert request.purpose == "batch"
def test_file_response(self):
"""Test file response model"""
response = FileResponse(
id="file-123",
bytes=1024,
created_at=1234567890,
filename="test.jsonl",
purpose="batch",
)
assert response.id == "file-123"
assert response.object == "file"
assert response.bytes == 1024
assert response.filename == "test.jsonl"
def test_file_delete_response(self):
"""Test file delete response model"""
response = FileDeleteResponse(id="file-123", deleted=True)
assert response.id == "file-123"
assert response.object == "file"
assert response.deleted
class TestBatchOperations:
"""Test batch operation protocol models"""
def test_batch_request(self):
"""Test batch request model"""
request = BatchRequest(
input_file_id="file-123",
endpoint="/v1/chat/completions",
completion_window="24h",
metadata={"custom": "value"},
)
assert request.input_file_id == "file-123"
assert request.endpoint == "/v1/chat/completions"
assert request.completion_window == "24h"
assert request.metadata == {"custom": "value"}
def test_batch_response(self):
"""Test batch response model"""
response = BatchResponse(
id="batch-123",
endpoint="/v1/chat/completions",
input_file_id="file-123",
completion_window="24h",
created_at=1234567890,
)
assert response.id == "batch-123"
assert response.object == "batch"
assert response.status == "validating" # default
assert response.endpoint == "/v1/chat/completions"
class TestResponseFormats:
"""Test response format protocol models"""
def test_basic_response_format(self):
"""Test basic response format"""
format_obj = ResponseFormat(type="json_object")
assert format_obj.type == "json_object"
assert format_obj.json_schema is None
def test_json_schema_response_format(self):
"""Test JSON schema response format"""
schema = {"type": "object", "properties": {"name": {"type": "string"}}}
json_schema = JsonSchemaResponseFormat(
name="person_schema", description="Person schema", schema=schema
)
format_obj = ResponseFormat(type="json_schema", json_schema=json_schema)
assert format_obj.type == "json_schema"
assert format_obj.json_schema.name == "person_schema"
assert format_obj.json_schema.schema_ == schema
def test_structural_tag_response_format(self):
"""Test structural tag response format"""
structures = [
{
"begin": "<thinking>",
"schema_": {"type": "string"},
"end": "</thinking>",
}
]
format_obj = StructuralTagResponseFormat(
type="structural_tag", structures=structures, triggers=["think"]
)
assert format_obj.type == "structural_tag"
assert len(format_obj.structures) == 1
assert format_obj.triggers == ["think"]
class TestLogProbs:
"""Test LogProbs protocol models"""
def test_basic_logprobs(self):
"""Test basic LogProbs model"""
logprobs = LogProbs(
text_offset=[0, 5, 11],
token_logprobs=[-0.1, -0.2, -0.3],
tokens=["Hello", " ", "world"],
top_logprobs=[{"Hello": -0.1}, {" ": -0.2}, {"world": -0.3}],
)
assert len(logprobs.tokens) == 3
assert logprobs.tokens == ["Hello", " ", "world"]
assert logprobs.token_logprobs == [-0.1, -0.2, -0.3]
def test_choice_logprobs(self):
"""Test ChoiceLogprobs model"""
token_logprob = ChatCompletionTokenLogprob(
token="Hello",
bytes=[72, 101, 108, 108, 111],
logprob=-0.1,
top_logprobs=[
TopLogprob(token="Hello", bytes=[72, 101, 108, 108, 111], logprob=-0.1)
],
)
choice_logprobs = ChoiceLogprobs(content=[token_logprob])
assert len(choice_logprobs.content) == 1
assert choice_logprobs.content[0].token == "Hello"
class TestStreamingModels:
"""Test streaming response models"""
def test_stream_options(self):
"""Test StreamOptions model"""
options = StreamOptions(include_usage=True)
assert options.include_usage
def test_chat_completion_stream_response(self):
"""Test ChatCompletionStreamResponse model"""
delta = DeltaMessage(role="assistant", content="Hello")
choice = ChatCompletionResponseStreamChoice(index=0, delta=delta)
response = ChatCompletionStreamResponse(
id="test-id", model="test-model", choices=[choice]
)
assert response.object == "chat.completion.chunk"
assert response.choices[0].delta.content == "Hello"
class TestValidationEdgeCases:
"""Test edge cases and validation scenarios"""
def test_empty_messages_validation(self):
"""Test validation with empty messages"""
with pytest.raises(ValidationError):
ChatCompletionRequest(model="test-model", messages=[])
def test_invalid_tool_choice_type(self):
"""Test invalid tool choice type"""
messages = [{"role": "user", "content": "Hello"}]
with pytest.raises(ValidationError):
ChatCompletionRequest(
model="test-model", messages=messages, tool_choice=123
)
def test_negative_token_limits(self):
"""Test negative token limits"""
with pytest.raises(ValidationError):
CompletionRequest(model="test-model", prompt="Hello", max_tokens=-1)
def test_invalid_temperature_range(self):
"""Test invalid temperature values"""
# Note: The current protocol doesn't enforce temperature range,
# but this test documents expected behavior
request = CompletionRequest(model="test-model", prompt="Hello", temperature=5.0)
assert request.temperature == 5.0 # Currently allowed
def test_model_serialization_roundtrip(self):
"""Test that models can be serialized and deserialized"""
original_request = ChatCompletionRequest(
model="test-model",
messages=[{"role": "user", "content": "Hello"}],
temperature=0.7,
max_tokens=100,
)
# Serialize to dict
data = original_request.model_dump()
# Deserialize back
restored_request = ChatCompletionRequest(**data)
assert restored_request.model == original_request.model
assert restored_request.temperature == original_request.temperature
assert restored_request.max_tokens == original_request.max_tokens
assert len(restored_request.messages) == len(original_request.messages)
if __name__ == "__main__":
pytest.main([__file__])
"""
Unit tests for the OpenAIServingChat class from serving_chat.py.
These tests ensure that the refactored implementation maintains compatibility
with the original adapter.py functionality.
"""
import uuid
from unittest.mock import Mock, patch
import pytest
from fastapi import Request
from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest, ErrorResponse
from sglang.srt.entrypoints.openai.serving_chat import OpenAIServingChat
from sglang.srt.managers.io_struct import GenerateReqInput
# Mock TokenizerManager since it may not be directly importable in tests
class MockTokenizerManager:
def __init__(self):
self.model_config = Mock()
self.model_config.is_multimodal = False
self.server_args = Mock()
self.server_args.enable_cache_report = False
self.server_args.tool_call_parser = "hermes"
self.server_args.reasoning_parser = None
self.chat_template_name = "llama-3"
# Mock tokenizer
self.tokenizer = Mock()
self.tokenizer.encode = Mock(return_value=[1, 2, 3, 4, 5])
self.tokenizer.decode = Mock(return_value="Test response")
self.tokenizer.chat_template = None
self.tokenizer.bos_token_id = 1
# Mock generate_request method
async def mock_generate():
yield {
"text": "Test response",
"meta_info": {
"id": f"chatcmpl-{uuid.uuid4()}",
"prompt_tokens": 10,
"completion_tokens": 5,
"cached_tokens": 0,
"finish_reason": {"type": "stop", "matched": None},
"output_token_logprobs": [(0.1, 1, "Test"), (0.2, 2, "response")],
"output_top_logprobs": None,
},
"index": 0,
}
self.generate_request = Mock(return_value=mock_generate())
self.create_abort_task = Mock(return_value=None)
@pytest.fixture
def mock_tokenizer_manager():
"""Create a mock tokenizer manager for testing."""
return MockTokenizerManager()
@pytest.fixture
def serving_chat(mock_tokenizer_manager):
"""Create a OpenAIServingChat instance for testing."""
return OpenAIServingChat(mock_tokenizer_manager)
@pytest.fixture
def mock_request():
"""Create a mock FastAPI request."""
request = Mock(spec=Request)
request.headers = {}
return request
@pytest.fixture
def basic_chat_request():
"""Create a basic chat completion request."""
return ChatCompletionRequest(
model="test-model",
messages=[{"role": "user", "content": "Hello, how are you?"}],
temperature=0.7,
max_tokens=100,
stream=False,
)
@pytest.fixture
def streaming_chat_request():
"""Create a streaming chat completion request."""
return ChatCompletionRequest(
model="test-model",
messages=[{"role": "user", "content": "Hello, how are you?"}],
temperature=0.7,
max_tokens=100,
stream=True,
)
class TestOpenAIServingChatConversion:
"""Test request conversion methods."""
def test_convert_to_internal_request_single(
self, serving_chat, basic_chat_request, mock_tokenizer_manager
):
"""Test converting single request to internal format."""
with patch(
"sglang.srt.entrypoints.openai.serving_chat.generate_chat_conv"
) as mock_conv:
mock_conv_instance = Mock()
mock_conv_instance.get_prompt.return_value = "Test prompt"
mock_conv_instance.image_data = None
mock_conv_instance.audio_data = None
mock_conv_instance.modalities = []
mock_conv_instance.stop_str = ["</s>"]
mock_conv.return_value = mock_conv_instance
# Mock the _process_messages method to return expected values
with patch.object(serving_chat, "_process_messages") as mock_process:
mock_process.return_value = (
"Test prompt",
[1, 2, 3],
None,
None,
[],
["</s>"],
None, # tool_call_constraint
)
adapted_request, processed_request = (
serving_chat._convert_to_internal_request(
[basic_chat_request], ["test-id"]
)
)
assert isinstance(adapted_request, GenerateReqInput)
assert adapted_request.stream == basic_chat_request.stream
assert processed_request == basic_chat_request
class TestToolCalls:
"""Test tool call functionality from adapter.py"""
def test_tool_call_request_conversion(self, serving_chat):
"""Test request with tool calls"""
request = ChatCompletionRequest(
model="test-model",
messages=[{"role": "user", "content": "What's the weather?"}],
tools=[
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get weather information",
"parameters": {
"type": "object",
"properties": {"location": {"type": "string"}},
},
},
}
],
tool_choice="auto",
)
with patch.object(serving_chat, "_process_messages") as mock_process:
mock_process.return_value = (
"Test prompt",
[1, 2, 3],
None,
None,
[],
["</s>"],
None, # tool_call_constraint
)
adapted_request, _ = serving_chat._convert_to_internal_request(
[request], ["test-id"]
)
assert adapted_request.rid == "test-id"
# Tool call constraint should be processed
assert request.tools is not None
def test_tool_choice_none(self, serving_chat):
"""Test tool_choice=none disables tool calls"""
request = ChatCompletionRequest(
model="test-model",
messages=[{"role": "user", "content": "Hello"}],
tools=[{"type": "function", "function": {"name": "test_func"}}],
tool_choice="none",
)
with patch.object(serving_chat, "_process_messages") as mock_process:
mock_process.return_value = (
"Test prompt",
[1, 2, 3],
None,
None,
[],
["</s>"],
None, # tool_call_constraint
)
adapted_request, _ = serving_chat._convert_to_internal_request(
[request], ["test-id"]
)
# Tools should not be processed when tool_choice is "none"
assert adapted_request.rid == "test-id"
def test_tool_call_response_processing(self, serving_chat):
"""Test processing tool calls in response"""
mock_ret_item = {
"text": '{"name": "get_weather", "parameters": {"location": "Paris"}}',
"meta_info": {
"output_token_logprobs": [],
"output_top_logprobs": None,
},
}
tools = [
{
"type": "function",
"function": {
"name": "get_weather",
"parameters": {
"type": "object",
"properties": {"location": {"type": "string"}},
},
},
}
]
finish_reason = {"type": "stop", "matched": None}
# Mock FunctionCallParser
with patch(
"sglang.srt.entrypoints.openai.serving_chat.FunctionCallParser"
) as mock_parser_class:
mock_parser = Mock()
mock_parser.has_tool_call.return_value = True
# Create proper mock tool call object
mock_tool_call = Mock()
mock_tool_call.name = "get_weather"
mock_tool_call.parameters = '{"location": "Paris"}'
mock_parser.parse_non_stream.return_value = ("", [mock_tool_call])
mock_parser_class.return_value = mock_parser
tool_calls, text, updated_finish_reason = serving_chat._process_tool_calls(
mock_ret_item["text"], tools, "hermes", finish_reason
)
assert tool_calls is not None
assert len(tool_calls) == 1
assert updated_finish_reason["type"] == "tool_calls"
class TestMultimodalContent:
"""Test multimodal content handling from adapter.py"""
def test_multimodal_request_with_images(self, serving_chat):
"""Test request with image content"""
request = ChatCompletionRequest(
model="test-model",
messages=[
{
"role": "user",
"content": [
{"type": "text", "text": "What's in this image?"},
{
"type": "image_url",
"image_url": {"url": "data:image/jpeg;base64,..."},
},
],
}
],
)
# Set multimodal mode
serving_chat.tokenizer_manager.model_config.is_multimodal = True
with patch.object(serving_chat, "_apply_jinja_template") as mock_apply:
mock_apply.return_value = (
"prompt",
[1, 2, 3],
["image_data"],
None,
[],
[],
)
with patch.object(
serving_chat, "_apply_conversation_template"
) as mock_conv:
mock_conv.return_value = ("prompt", ["image_data"], None, [], [])
(
prompt,
prompt_ids,
image_data,
audio_data,
modalities,
stop,
tool_call_constraint,
) = serving_chat._process_messages(request, True)
assert image_data == ["image_data"]
assert prompt == "prompt"
def test_multimodal_request_with_audio(self, serving_chat):
"""Test request with audio content"""
request = ChatCompletionRequest(
model="test-model",
messages=[
{
"role": "user",
"content": [
{"type": "text", "text": "Transcribe this audio"},
{
"type": "audio_url",
"audio_url": {"url": "data:audio/wav;base64,UklGR..."},
},
],
}
],
)
serving_chat.tokenizer_manager.model_config.is_multimodal = True
with patch.object(serving_chat, "_apply_jinja_template") as mock_apply:
mock_apply.return_value = (
"prompt",
[1, 2, 3],
None,
["audio_data"],
["audio"],
[],
)
with patch.object(
serving_chat, "_apply_conversation_template"
) as mock_conv:
mock_conv.return_value = ("prompt", None, ["audio_data"], ["audio"], [])
(
prompt,
prompt_ids,
image_data,
audio_data,
modalities,
stop,
tool_call_constraint,
) = serving_chat._process_messages(request, True)
assert audio_data == ["audio_data"]
assert modalities == ["audio"]
class TestTemplateHandling:
"""Test chat template handling from adapter.py"""
def test_jinja_template_processing(self, serving_chat):
"""Test Jinja template processing"""
request = ChatCompletionRequest(
model="test-model", messages=[{"role": "user", "content": "Hello"}]
)
# Mock the template attribute directly
serving_chat.tokenizer_manager.chat_template_name = None
serving_chat.tokenizer_manager.tokenizer.chat_template = "<jinja_template>"
with patch.object(serving_chat, "_apply_jinja_template") as mock_apply:
mock_apply.return_value = (
"processed_prompt",
[1, 2, 3],
None,
None,
[],
["</s>"],
)
# Mock hasattr to simulate the None check
with patch("builtins.hasattr") as mock_hasattr:
mock_hasattr.return_value = True
(
prompt,
prompt_ids,
image_data,
audio_data,
modalities,
stop,
tool_call_constraint,
) = serving_chat._process_messages(request, False)
assert prompt == "processed_prompt"
assert prompt_ids == [1, 2, 3]
def test_conversation_template_processing(self, serving_chat):
"""Test conversation template processing"""
request = ChatCompletionRequest(
model="test-model", messages=[{"role": "user", "content": "Hello"}]
)
serving_chat.tokenizer_manager.chat_template_name = "llama-3"
with patch.object(serving_chat, "_apply_conversation_template") as mock_apply:
mock_apply.return_value = ("conv_prompt", None, None, [], ["</s>"])
(
prompt,
prompt_ids,
image_data,
audio_data,
modalities,
stop,
tool_call_constraint,
) = serving_chat._process_messages(request, False)
assert prompt == "conv_prompt"
assert stop == ["</s>"]
def test_continue_final_message(self, serving_chat):
"""Test continue_final_message functionality"""
request = ChatCompletionRequest(
model="test-model",
messages=[
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi there"},
],
continue_final_message=True,
)
with patch.object(serving_chat, "_apply_conversation_template") as mock_apply:
mock_apply.return_value = ("Hi there", None, None, [], ["</s>"])
(
prompt,
prompt_ids,
image_data,
audio_data,
modalities,
stop,
tool_call_constraint,
) = serving_chat._process_messages(request, False)
# Should handle continue_final_message properly
assert prompt == "Hi there"
class TestReasoningContent:
"""Test reasoning content separation from adapter.py"""
def test_reasoning_content_request(self, serving_chat):
"""Test request with reasoning content separation"""
request = ChatCompletionRequest(
model="test-model",
messages=[{"role": "user", "content": "Solve this math problem"}],
separate_reasoning=True,
stream_reasoning=False,
)
with patch.object(serving_chat, "_process_messages") as mock_process:
mock_process.return_value = (
"Test prompt",
[1, 2, 3],
None,
None,
[],
["</s>"],
None, # tool_call_constraint
)
adapted_request, _ = serving_chat._convert_to_internal_request(
[request], ["test-id"]
)
assert adapted_request.rid == "test-id"
            assert request.separate_reasoning
def test_reasoning_content_response(self, serving_chat):
"""Test reasoning content in response"""
mock_ret_item = {
"text": "<thinking>This is reasoning</thinking>Answer: 42",
"meta_info": {
"output_token_logprobs": [],
"output_top_logprobs": None,
},
}
# Mock ReasoningParser
with patch(
"sglang.srt.entrypoints.openai.serving_chat.ReasoningParser"
) as mock_parser_class:
mock_parser = Mock()
mock_parser.parse_non_stream.return_value = (
"This is reasoning",
"Answer: 42",
)
mock_parser_class.return_value = mock_parser
choice_logprobs = None
reasoning_text = None
text = mock_ret_item["text"]
# Simulate reasoning processing
enable_thinking = True
if enable_thinking:
parser = mock_parser_class(model_type="test", stream_reasoning=False)
reasoning_text, text = parser.parse_non_stream(text)
assert reasoning_text == "This is reasoning"
assert text == "Answer: 42"
class TestSamplingParams:
"""Test sampling parameter handling from adapter.py"""
def test_all_sampling_parameters(self, serving_chat):
"""Test all sampling parameters are properly handled"""
request = ChatCompletionRequest(
model="test-model",
messages=[{"role": "user", "content": "Hello"}],
temperature=0.8,
max_tokens=150,
max_completion_tokens=200,
min_tokens=5,
top_p=0.9,
top_k=50,
min_p=0.1,
presence_penalty=0.1,
frequency_penalty=0.2,
repetition_penalty=1.1,
stop=["<|endoftext|>"],
stop_token_ids=[13, 14],
regex=r"\d+",
ebnf="<expr> ::= <number>",
n=2,
no_stop_trim=True,
ignore_eos=True,
skip_special_tokens=False,
logit_bias={"1": 0.5, "2": -0.3},
)
with patch.object(serving_chat, "_process_messages") as mock_process:
mock_process.return_value = (
"Test prompt",
[1, 2, 3],
None,
None,
[],
["</s>"],
None, # tool_call_constraint
)
sampling_params = serving_chat._build_sampling_params(
request, ["</s>"], None
)
# Verify all parameters
assert sampling_params["temperature"] == 0.8
assert sampling_params["max_new_tokens"] == 150
assert sampling_params["min_new_tokens"] == 5
assert sampling_params["top_p"] == 0.9
assert sampling_params["top_k"] == 50
assert sampling_params["min_p"] == 0.1
assert sampling_params["presence_penalty"] == 0.1
assert sampling_params["frequency_penalty"] == 0.2
assert sampling_params["repetition_penalty"] == 1.1
assert sampling_params["stop"] == ["</s>"]
assert sampling_params["logit_bias"] == {"1": 0.5, "2": -0.3}
def test_response_format_json_schema(self, serving_chat):
"""Test response format with JSON schema"""
request = ChatCompletionRequest(
model="test-model",
messages=[{"role": "user", "content": "Generate JSON"}],
response_format={
"type": "json_schema",
"json_schema": {
"name": "response",
"schema": {
"type": "object",
"properties": {"answer": {"type": "string"}},
},
},
},
)
with patch.object(serving_chat, "_process_messages") as mock_process:
mock_process.return_value = (
"Test prompt",
[1, 2, 3],
None,
None,
[],
["</s>"],
None, # tool_call_constraint
)
sampling_params = serving_chat._build_sampling_params(
request, ["</s>"], None
)
assert "json_schema" in sampling_params
assert '"type": "object"' in sampling_params["json_schema"]
def test_response_format_json_object(self, serving_chat):
"""Test response format with JSON object"""
request = ChatCompletionRequest(
model="test-model",
messages=[{"role": "user", "content": "Generate JSON"}],
response_format={"type": "json_object"},
)
with patch.object(serving_chat, "_process_messages") as mock_process:
mock_process.return_value = (
"Test prompt",
[1, 2, 3],
None,
None,
[],
["</s>"],
None, # tool_call_constraint
)
sampling_params = serving_chat._build_sampling_params(
request, ["</s>"], None
)
assert sampling_params["json_schema"] == '{"type": "object"}'
"""
Tests for the refactored completions serving handler
"""
from unittest.mock import AsyncMock, Mock, patch
import pytest
from sglang.srt.entrypoints.openai.protocol import (
CompletionRequest,
CompletionResponse,
CompletionResponseChoice,
CompletionStreamResponse,
ErrorResponse,
)
from sglang.srt.entrypoints.openai.serving_completions import OpenAIServingCompletion
from sglang.srt.managers.io_struct import GenerateReqInput
from sglang.srt.managers.tokenizer_manager import TokenizerManager
@pytest.fixture
def mock_tokenizer_manager():
"""Create a mock tokenizer manager"""
manager = Mock(spec=TokenizerManager)
# Mock tokenizer
manager.tokenizer = Mock()
manager.tokenizer.encode = Mock(return_value=[1, 2, 3, 4])
manager.tokenizer.decode = Mock(return_value="decoded text")
manager.tokenizer.bos_token_id = 1
# Mock model config
manager.model_config = Mock()
manager.model_config.is_multimodal = False
# Mock server args
manager.server_args = Mock()
manager.server_args.enable_cache_report = False
# Mock generation
manager.generate_request = AsyncMock()
manager.create_abort_task = Mock(return_value=None)
return manager
@pytest.fixture
def serving_completion(mock_tokenizer_manager):
"""Create a OpenAIServingCompletion instance"""
return OpenAIServingCompletion(mock_tokenizer_manager)
class TestPromptHandling:
"""Test different prompt types and formats from adapter.py"""
def test_single_string_prompt(self, serving_completion):
"""Test handling single string prompt"""
request = CompletionRequest(
model="test-model", prompt="Hello world", max_tokens=100
)
adapted_request, _ = serving_completion._convert_to_internal_request(
[request], ["test-id"]
)
assert adapted_request.text == "Hello world"
def test_single_token_ids_prompt(self, serving_completion):
"""Test handling single token IDs prompt"""
request = CompletionRequest(
model="test-model", prompt=[1, 2, 3, 4], max_tokens=100
)
adapted_request, _ = serving_completion._convert_to_internal_request(
[request], ["test-id"]
)
assert adapted_request.input_ids == [1, 2, 3, 4]
def test_completion_template_handling(self, serving_completion):
"""Test completion template processing"""
request = CompletionRequest(
model="test-model",
prompt="def hello():",
suffix="return 'world'",
max_tokens=100,
)
with patch(
"sglang.srt.entrypoints.openai.serving_completions.is_completion_template_defined",
return_value=True,
):
with patch(
"sglang.srt.entrypoints.openai.serving_completions.generate_completion_prompt_from_request",
return_value="processed_prompt",
):
adapted_request, _ = serving_completion._convert_to_internal_request(
[request], ["test-id"]
)
assert adapted_request.text == "processed_prompt"
class TestEchoHandling:
"""Test echo functionality from adapter.py"""
def test_echo_with_string_prompt_streaming(self, serving_completion):
"""Test echo handling with string prompt in streaming"""
request = CompletionRequest(
model="test-model", prompt="Hello", max_tokens=100, echo=True
)
# Test _get_echo_text method
echo_text = serving_completion._get_echo_text(request, 0)
assert echo_text == "Hello"
def test_echo_with_list_of_strings_streaming(self, serving_completion):
"""Test echo handling with list of strings in streaming"""
request = CompletionRequest(
model="test-model",
prompt=["Hello", "World"],
max_tokens=100,
echo=True,
n=1,
)
echo_text = serving_completion._get_echo_text(request, 0)
assert echo_text == "Hello"
echo_text = serving_completion._get_echo_text(request, 1)
assert echo_text == "World"
def test_echo_with_token_ids_streaming(self, serving_completion):
"""Test echo handling with token IDs in streaming"""
request = CompletionRequest(
model="test-model", prompt=[1, 2, 3], max_tokens=100, echo=True
)
serving_completion.tokenizer_manager.tokenizer.decode.return_value = (
"decoded_prompt"
)
echo_text = serving_completion._get_echo_text(request, 0)
assert echo_text == "decoded_prompt"
def test_echo_with_multiple_token_ids_streaming(self, serving_completion):
"""Test echo handling with multiple token ID prompts in streaming"""
request = CompletionRequest(
model="test-model", prompt=[[1, 2], [3, 4]], max_tokens=100, echo=True, n=1
)
serving_completion.tokenizer_manager.tokenizer.decode.return_value = "decoded"
echo_text = serving_completion._get_echo_text(request, 0)
assert echo_text == "decoded"
def test_prepare_echo_prompts_non_streaming(self, serving_completion):
"""Test prepare echo prompts for non-streaming response"""
# Test with single string
request = CompletionRequest(model="test-model", prompt="Hello", echo=True)
echo_prompts = serving_completion._prepare_echo_prompts(request)
assert echo_prompts == ["Hello"]
# Test with list of strings
request = CompletionRequest(
model="test-model", prompt=["Hello", "World"], echo=True
)
echo_prompts = serving_completion._prepare_echo_prompts(request)
assert echo_prompts == ["Hello", "World"]
# Test with token IDs
request = CompletionRequest(model="test-model", prompt=[1, 2, 3], echo=True)
serving_completion.tokenizer_manager.tokenizer.decode.return_value = "decoded"
echo_prompts = serving_completion._prepare_echo_prompts(request)
assert echo_prompts == ["decoded"]
"""
Unit tests for the OpenAIServingEmbedding class from serving_embedding.py.
These tests ensure that the embedding serving implementation maintains compatibility
with the original adapter.py functionality and follows OpenAI API specifications.
"""
import asyncio
import json
import time
import uuid
from typing import Any, Dict, List
from unittest.mock import AsyncMock, Mock, patch
import pytest
from fastapi import Request
from fastapi.responses import ORJSONResponse
from pydantic_core import ValidationError
from sglang.srt.entrypoints.openai.protocol import (
EmbeddingObject,
EmbeddingRequest,
EmbeddingResponse,
ErrorResponse,
MultimodalEmbeddingInput,
UsageInfo,
)
from sglang.srt.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from sglang.srt.managers.io_struct import EmbeddingReqInput
# Mock TokenizerManager for embedding tests
class MockTokenizerManager:
def __init__(self):
self.model_config = Mock()
self.model_config.is_multimodal = False
self.server_args = Mock()
self.server_args.enable_cache_report = False
self.model_path = "test-model"
# Mock tokenizer
self.tokenizer = Mock()
self.tokenizer.encode = Mock(return_value=[1, 2, 3, 4, 5])
self.tokenizer.decode = Mock(return_value="Test embedding input")
self.tokenizer.chat_template = None
self.tokenizer.bos_token_id = 1
# Mock generate_request method for embeddings
async def mock_generate_embedding():
yield {
"embedding": [0.1, 0.2, 0.3, 0.4, 0.5] * 20, # 100-dim embedding
"meta_info": {
"id": f"embd-{uuid.uuid4()}",
"prompt_tokens": 5,
},
}
self.generate_request = Mock(return_value=mock_generate_embedding())
@pytest.fixture
def mock_tokenizer_manager():
"""Create a mock tokenizer manager for testing."""
return MockTokenizerManager()
@pytest.fixture
def serving_embedding(mock_tokenizer_manager):
"""Create an OpenAIServingEmbedding instance for testing."""
return OpenAIServingEmbedding(mock_tokenizer_manager)
@pytest.fixture
def mock_request():
"""Create a mock FastAPI request."""
request = Mock(spec=Request)
request.headers = {}
return request
@pytest.fixture
def basic_embedding_request():
"""Create a basic embedding request."""
return EmbeddingRequest(
model="test-model",
input="Hello, how are you?",
encoding_format="float",
)
@pytest.fixture
def list_embedding_request():
"""Create an embedding request with list input."""
return EmbeddingRequest(
model="test-model",
input=["Hello, how are you?", "I am fine, thank you!"],
encoding_format="float",
)
@pytest.fixture
def multimodal_embedding_request():
"""Create a multimodal embedding request."""
return EmbeddingRequest(
model="test-model",
input=[
MultimodalEmbeddingInput(text="Hello", image="base64_image_data"),
MultimodalEmbeddingInput(text="World", image=None),
],
encoding_format="float",
)
@pytest.fixture
def token_ids_embedding_request():
"""Create an embedding request with token IDs."""
return EmbeddingRequest(
model="test-model",
input=[1, 2, 3, 4, 5],
encoding_format="float",
)
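# Taken together, the fixtures above wire up the objects each test needs. A
# fixture-free equivalent (mirroring serving_embedding plus
# basic_embedding_request) would look roughly like:
#
#     manager = MockTokenizerManager()
#     serving = OpenAIServingEmbedding(manager)
#     request = EmbeddingRequest(model="test-model", input="Hello, how are you?", encoding_format="float")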
class TestOpenAIServingEmbeddingConversion:
"""Test request conversion methods."""
def test_convert_single_string_request(
self, serving_embedding, basic_embedding_request
):
"""Test converting single string request to internal format."""
adapted_request, processed_request = (
serving_embedding._convert_to_internal_request(
[basic_embedding_request], ["test-id"]
)
)
assert isinstance(adapted_request, EmbeddingReqInput)
assert adapted_request.text == "Hello, how are you?"
assert adapted_request.rid == "test-id"
assert processed_request == basic_embedding_request
def test_convert_list_string_request(
self, serving_embedding, list_embedding_request
):
"""Test converting list of strings request to internal format."""
adapted_request, processed_request = (
serving_embedding._convert_to_internal_request(
[list_embedding_request], ["test-id"]
)
)
assert isinstance(adapted_request, EmbeddingReqInput)
assert adapted_request.text == ["Hello, how are you?", "I am fine, thank you!"]
assert adapted_request.rid == "test-id"
assert processed_request == list_embedding_request
def test_convert_token_ids_request(
self, serving_embedding, token_ids_embedding_request
):
"""Test converting token IDs request to internal format."""
adapted_request, processed_request = (
serving_embedding._convert_to_internal_request(
[token_ids_embedding_request], ["test-id"]
)
)
assert isinstance(adapted_request, EmbeddingReqInput)
assert adapted_request.input_ids == [1, 2, 3, 4, 5]
assert adapted_request.rid == "test-id"
assert processed_request == token_ids_embedding_request
def test_convert_multimodal_request(
self, serving_embedding, multimodal_embedding_request
):
"""Test converting multimodal request to internal format."""
adapted_request, processed_request = (
serving_embedding._convert_to_internal_request(
[multimodal_embedding_request], ["test-id"]
)
)
assert isinstance(adapted_request, EmbeddingReqInput)
# Should extract text and images separately
assert len(adapted_request.text) == 2
assert "Hello" in adapted_request.text
assert "World" in adapted_request.text
assert adapted_request.image_data[0] == "base64_image_data"
assert adapted_request.image_data[1] is None
assert adapted_request.rid == "test-id"
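# The assertions above imply that multimodal inputs are split into parallel
# text and image lists before being handed to EmbeddingReqInput. A hedged
# sketch of that conversion (not necessarily the actual implementation):
#
#     texts = [item.text for item in request.input]
#     images = [item.image for item in request.input]
#     adapted = EmbeddingReqInput(text=texts, image_data=images, rid=rid)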
class TestEmbeddingResponseBuilding:
"""Test response building methods."""
def test_build_single_embedding_response(self, serving_embedding):
"""Test building response for single embedding."""
ret_data = [
{
"embedding": [0.1, 0.2, 0.3, 0.4, 0.5],
"meta_info": {"prompt_tokens": 5},
}
]
response = serving_embedding._build_embedding_response(ret_data, "test-model")
assert isinstance(response, EmbeddingResponse)
assert response.model == "test-model"
assert len(response.data) == 1
assert response.data[0].embedding == [0.1, 0.2, 0.3, 0.4, 0.5]
assert response.data[0].index == 0
assert response.data[0].object == "embedding"
assert response.usage.prompt_tokens == 5
assert response.usage.total_tokens == 5
assert response.usage.completion_tokens == 0
def test_build_multiple_embedding_response(self, serving_embedding):
"""Test building response for multiple embeddings."""
ret_data = [
{
"embedding": [0.1, 0.2, 0.3],
"meta_info": {"prompt_tokens": 3},
},
{
"embedding": [0.4, 0.5, 0.6],
"meta_info": {"prompt_tokens": 4},
},
]
response = serving_embedding._build_embedding_response(ret_data, "test-model")
assert isinstance(response, EmbeddingResponse)
assert len(response.data) == 2
assert response.data[0].embedding == [0.1, 0.2, 0.3]
assert response.data[0].index == 0
assert response.data[1].embedding == [0.4, 0.5, 0.6]
assert response.data[1].index == 1
assert response.usage.prompt_tokens == 7 # 3 + 4
assert response.usage.total_tokens == 7
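# A small, self-contained sketch of the usage aggregation these assertions
# imply: prompt_tokens are summed across all returned embeddings and
# completion_tokens stays 0. This helper is illustrative only and is not part
# of the serving implementation under test.
def _expected_usage(ret_data):
    prompt_tokens = sum(item["meta_info"]["prompt_tokens"] for item in ret_data)
    return UsageInfo(
        prompt_tokens=prompt_tokens,
        total_tokens=prompt_tokens,
        completion_tokens=0,
    )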
@pytest.mark.asyncio
class TestOpenAIServingEmbeddingAsyncMethods:
"""Test async methods of OpenAIServingEmbedding."""
async def test_handle_request_success(
self, serving_embedding, basic_embedding_request, mock_request
):
"""Test successful embedding request handling."""
# Mock the generate_request to return expected data
async def mock_generate():
yield {
"embedding": [0.1, 0.2, 0.3, 0.4, 0.5],
"meta_info": {"prompt_tokens": 5},
}
serving_embedding.tokenizer_manager.generate_request = Mock(
return_value=mock_generate()
)
response = await serving_embedding.handle_request(
basic_embedding_request, mock_request
)
assert isinstance(response, EmbeddingResponse)
assert len(response.data) == 1
assert response.data[0].embedding == [0.1, 0.2, 0.3, 0.4, 0.5]
async def test_handle_request_validation_error(
self, serving_embedding, mock_request
):
"""Test handling request with validation error."""
invalid_request = EmbeddingRequest(model="test-model", input="")
response = await serving_embedding.handle_request(invalid_request, mock_request)
assert isinstance(response, ORJSONResponse)
assert response.status_code == 400
async def test_handle_request_generation_error(
self, serving_embedding, basic_embedding_request, mock_request
):
"""Test handling request with generation error."""
# Mock generate_request to raise an error
async def mock_generate_error():
raise ValueError("Generation failed")
yield # This won't be reached but needed for async generator
serving_embedding.tokenizer_manager.generate_request = Mock(
return_value=mock_generate_error()
)
response = await serving_embedding.handle_request(
basic_embedding_request, mock_request
)
assert isinstance(response, ORJSONResponse)
assert response.status_code == 400
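    # The two error-path tests in this class draw a deliberate line: errors
    # raised while generating (e.g. a ValueError from generate_request) are
    # reported as 400 Bad Request, while unexpected failures inside the serving
    # layer itself (see the next test) surface as 500 Internal Server Error.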
async def test_handle_request_internal_error(
self, serving_embedding, basic_embedding_request, mock_request
):
"""Test handling request with internal server error."""
# Mock _convert_to_internal_request to raise an exception
with patch.object(
serving_embedding,
"_convert_to_internal_request",
side_effect=Exception("Internal error"),
):
response = await serving_embedding.handle_request(
basic_embedding_request, mock_request
)
assert isinstance(response, ORJSONResponse)
assert response.status_code == 500
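# These async tests require pytest-asyncio (the class above is marked with
# @pytest.mark.asyncio). As a convenience, the module can also be run directly;
# this entry point is an illustrative addition, not part of the original tests.
if __name__ == "__main__":
    import sys

    # Run only this module with verbose output and propagate pytest's exit code.
    sys.exit(pytest.main([__file__, "-v"]))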