Unverified Commit 72676cd6 authored by Chang Su's avatar Chang Su Committed by GitHub
Browse files

feat(oai refactor): Replace `openai_api` with `entrypoints/openai` (#7351)


Co-authored-by: default avatarJin Pan <jpan236@wisc.edu>
parent 02bf31ef
......@@ -2,6 +2,7 @@ import json
import logging
from typing import List
from sglang.srt.entrypoints.openai.protocol import Tool
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
from sglang.srt.function_call.core_types import (
StreamingParseResult,
......@@ -9,7 +10,6 @@ from sglang.srt.function_call.core_types import (
_GetInfoFunc,
)
from sglang.srt.function_call.ebnf_composer import EBNFComposer
from sglang.srt.openai_api.protocol import Tool
logger = logging.getLogger(__name__)
......
......@@ -3,6 +3,7 @@ import logging
import re
from typing import List
from sglang.srt.entrypoints.openai.protocol import Tool
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
from sglang.srt.function_call.core_types import (
StreamingParseResult,
......@@ -10,7 +11,6 @@ from sglang.srt.function_call.core_types import (
_GetInfoFunc,
)
from sglang.srt.function_call.ebnf_composer import EBNFComposer
from sglang.srt.openai_api.protocol import Tool
logger = logging.getLogger(__name__)
......
......@@ -4,6 +4,7 @@ import logging
import re
from typing import List, Optional
from sglang.srt.entrypoints.openai.protocol import Tool
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
from sglang.srt.function_call.core_types import (
StreamingParseResult,
......@@ -12,7 +13,6 @@ from sglang.srt.function_call.core_types import (
_GetInfoFunc,
)
from sglang.srt.function_call.ebnf_composer import EBNFComposer
from sglang.srt.openai_api.protocol import Tool
logger = logging.getLogger(__name__)
......
......@@ -3,6 +3,7 @@ import logging
import re
from typing import List
from sglang.srt.entrypoints.openai.protocol import Tool
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
from sglang.srt.function_call.core_types import (
StreamingParseResult,
......@@ -10,7 +11,6 @@ from sglang.srt.function_call.core_types import (
_GetInfoFunc,
)
from sglang.srt.function_call.ebnf_composer import EBNFComposer
from sglang.srt.openai_api.protocol import Tool
logger = logging.getLogger(__name__)
......
"""
Utility functions for OpenAI API adapter.
"""Template utilities for Jinja template processing.
This module provides utilities for analyzing and processing Jinja chat templates,
including content format detection and message processing.
"""
import logging
from typing import Dict, List
import jinja2.nodes
import jinja2
import transformers.utils.chat_template_utils as hf_chat_utils
logger = logging.getLogger(__name__)
......@@ -75,7 +76,7 @@ def _try_extract_ast(chat_template: str):
return None
def detect_template_content_format(chat_template: str) -> str:
def detect_jinja_template_content_format(chat_template: str) -> str:
"""
Detect whether a chat template expects 'string' or 'openai' content format.
......
......@@ -864,12 +864,6 @@ class SetInternalStateReq:
server_args: Dict[str, Any]
@dataclass
class V1RerankReqInput:
query: str
documents: List[str]
@dataclass
class SetInternalStateReqOutput:
updated: bool
......
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Centralized template management for chat templates and completion templates.
This module provides a unified interface for managing both chat conversation templates
and code completion templates, eliminating global state and improving modularity.
"""
import json
import logging
import os
from typing import Optional
from sglang.srt.code_completion_parser import (
CompletionTemplate,
FimPosition,
completion_template_exists,
register_completion_template,
)
from sglang.srt.conversation import (
Conversation,
SeparatorStyle,
chat_template_exists,
get_conv_template_by_model_path,
register_conv_template,
)
from sglang.srt.jinja_template_utils import detect_jinja_template_content_format
logger = logging.getLogger(__name__)
class TemplateManager:
"""
Centralized manager for chat and completion templates.
This class encapsulates all template-related state and operations,
eliminating the need for global variables and providing a clean
interface for template management.
"""
def __init__(self):
self._chat_template_name: Optional[str] = None
self._completion_template_name: Optional[str] = None
self._jinja_template_content_format: Optional[str] = None
@property
def chat_template_name(self) -> Optional[str]:
"""Get the current chat template name."""
return self._chat_template_name
@property
def completion_template_name(self) -> Optional[str]:
"""Get the current completion template name."""
return self._completion_template_name
@property
def jinja_template_content_format(self) -> Optional[str]:
"""Get the detected template content format ('string' or 'openai' or None)."""
return self._jinja_template_content_format
def load_chat_template(
self, tokenizer_manager, chat_template_arg: str, model_path: str
) -> None:
"""
Load a chat template from various sources.
Args:
tokenizer_manager: The tokenizer manager instance
chat_template_arg: Template name or file path
model_path: Path to the model
"""
logger.info(f"Loading chat template: {chat_template_arg}")
if not chat_template_exists(chat_template_arg):
if not os.path.exists(chat_template_arg):
raise RuntimeError(
f"Chat template {chat_template_arg} is not a built-in template name "
"or a valid chat template file path."
)
if chat_template_arg.endswith(".jinja"):
self._load_jinja_template(tokenizer_manager, chat_template_arg)
else:
self._load_json_chat_template(chat_template_arg)
else:
self._chat_template_name = chat_template_arg
def guess_chat_template_from_model_path(self, model_path: str) -> None:
"""
Infer chat template name from model path.
Args:
model_path: Path to the model
"""
template_name = get_conv_template_by_model_path(model_path)
if template_name is not None:
logger.info(f"Inferred chat template from model path: {template_name}")
self._chat_template_name = template_name
def load_completion_template(self, completion_template_arg: str) -> None:
"""
Load completion template for code completion.
Args:
completion_template_arg: Template name or file path
"""
logger.info(f"Loading completion template: {completion_template_arg}")
if not completion_template_exists(completion_template_arg):
if not os.path.exists(completion_template_arg):
raise RuntimeError(
f"Completion template {completion_template_arg} is not a built-in template name "
"or a valid completion template file path."
)
self._load_json_completion_template(completion_template_arg)
else:
self._completion_template_name = completion_template_arg
def initialize_templates(
self,
tokenizer_manager,
model_path: str,
chat_template: Optional[str] = None,
completion_template: Optional[str] = None,
) -> None:
"""
Initialize all templates based on provided configuration.
Args:
tokenizer_manager: The tokenizer manager instance
model_path: Path to the model
chat_template: Optional chat template name/path
completion_template: Optional completion template name/path
"""
# Load chat template
if chat_template:
self.load_chat_template(tokenizer_manager, chat_template, model_path)
else:
self.guess_chat_template_from_model_path(model_path)
# Load completion template
if completion_template:
self.load_completion_template(completion_template)
def _load_jinja_template(self, tokenizer_manager, template_path: str) -> None:
"""Load a Jinja template file."""
with open(template_path, "r") as f:
chat_template = "".join(f.readlines()).strip("\n")
tokenizer_manager.tokenizer.chat_template = chat_template.replace("\\n", "\n")
self._chat_template_name = None
# Detect content format from the loaded template
self._jinja_template_content_format = detect_jinja_template_content_format(
chat_template
)
logger.info(
f"Detected chat template content format: {self._jinja_template_content_format}"
)
def _load_json_chat_template(self, template_path: str) -> None:
"""Load a JSON chat template file."""
assert template_path.endswith(
".json"
), "unrecognized format of chat template file"
with open(template_path, "r") as filep:
template = json.load(filep)
try:
sep_style = SeparatorStyle[template["sep_style"]]
except KeyError:
raise ValueError(
f"Unknown separator style: {template['sep_style']}"
) from None
register_conv_template(
Conversation(
name=template["name"],
system_template=template["system"] + "\n{system_message}",
system_message=template.get("system_message", ""),
roles=(template["user"], template["assistant"]),
sep_style=sep_style,
sep=template.get("sep", "\n"),
stop_str=template["stop_str"],
),
override=True,
)
self._chat_template_name = template["name"]
def _load_json_completion_template(self, template_path: str) -> None:
"""Load a JSON completion template file."""
assert template_path.endswith(
".json"
), "unrecognized format of completion template file"
with open(template_path, "r") as filep:
template = json.load(filep)
try:
fim_position = FimPosition[template["fim_position"]]
except KeyError:
raise ValueError(
f"Unknown fim position: {template['fim_position']}"
) from None
register_completion_template(
CompletionTemplate(
name=template["name"],
fim_begin_token=template["fim_begin_token"],
fim_middle_token=template["fim_middle_token"],
fim_end_token=template["fim_end_token"],
fim_position=fim_position,
),
override=True,
)
self._completion_template_name = template["name"]
......@@ -1058,12 +1058,7 @@ class TokenizerManager:
"lora_path",
]
)
out_skip_names = set(
[
"text",
"output_ids",
]
)
out_skip_names = set(["text", "output_ids", "embedding"])
elif self.log_requests_level == 1:
max_length = 2048
elif self.log_requests_level == 2:
......
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Conversion between OpenAI APIs and native SRT APIs"""
import asyncio
import base64
import json
import logging
import os
import time
import uuid
from http import HTTPStatus
from typing import Dict, List
from fastapi import HTTPException, Request, UploadFile
from fastapi.responses import ORJSONResponse, StreamingResponse
from pydantic import ValidationError
from sglang.srt.code_completion_parser import (
generate_completion_prompt_from_request,
is_completion_template_defined,
)
from sglang.srt.conversation import (
Conversation,
SeparatorStyle,
chat_template_exists,
generate_chat_conv,
generate_embedding_convs,
get_conv_template_by_model_path,
register_conv_template,
)
from sglang.srt.function_call.function_call_parser import FunctionCallParser
from sglang.srt.managers.io_struct import (
EmbeddingReqInput,
GenerateReqInput,
V1RerankReqInput,
)
from sglang.srt.openai_api.protocol import (
BatchRequest,
BatchResponse,
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseChoice,
ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse,
ChatCompletionTokenLogprob,
ChatMessage,
ChoiceLogprobs,
CompletionRequest,
CompletionResponse,
CompletionResponseChoice,
CompletionResponseStreamChoice,
CompletionStreamResponse,
DeltaMessage,
EmbeddingObject,
EmbeddingRequest,
EmbeddingResponse,
ErrorResponse,
FileDeleteResponse,
FileRequest,
FileResponse,
FunctionResponse,
LogProbs,
MultimodalEmbeddingInput,
RerankResponse,
ScoringRequest,
ScoringResponse,
ToolCall,
TopLogprob,
UsageInfo,
)
from sglang.srt.openai_api.utils import (
detect_template_content_format,
process_content_for_template_format,
)
from sglang.srt.reasoning_parser import ReasoningParser
from sglang.utils import convert_json_schema_to_str, get_exception_traceback
logger = logging.getLogger(__name__)
chat_template_name = None
# Global cache for template content format detection (one model/template per instance)
# NOTE: A better approach would be to initialize the chat template format when the endpoint is created
_cached_chat_template = None
_cached_template_format = None
class FileMetadata:
def __init__(self, filename: str, purpose: str):
self.filename = filename
self.purpose = purpose
# In-memory storage for batch jobs and files
batch_storage: Dict[str, BatchResponse] = {}
file_id_request: Dict[str, FileMetadata] = {}
file_id_response: Dict[str, FileResponse] = {}
# map file id to file path in SGLang backend
file_id_storage: Dict[str, str] = {}
# backend storage directory
storage_dir = None
def create_error_response(
message: str,
err_type: str = "BadRequestError",
status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
):
error = ErrorResponse(message=message, type=err_type, code=status_code.value)
return ORJSONResponse(content=error.model_dump(), status_code=error.code)
def create_streaming_error_response(
message: str,
err_type: str = "BadRequestError",
status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
) -> str:
error = ErrorResponse(message=message, type=err_type, code=status_code.value)
json_str = json.dumps({"error": error.model_dump()})
return json_str
def load_chat_template_for_openai_api(tokenizer_manager, chat_template_arg, model_path):
global chat_template_name
logger.info(
f"Use chat template for the OpenAI-compatible API server: {chat_template_arg}"
)
if not chat_template_exists(chat_template_arg):
if not os.path.exists(chat_template_arg):
raise RuntimeError(
f"Chat template {chat_template_arg} is not a built-in template name "
"or a valid chat template file path."
)
if chat_template_arg.endswith(".jinja"):
with open(chat_template_arg, "r") as f:
chat_template = "".join(f.readlines()).strip("\n")
tokenizer_manager.tokenizer.chat_template = chat_template.replace(
"\\n", "\n"
)
chat_template_name = None
else:
assert chat_template_arg.endswith(
".json"
), "unrecognized format of chat template file"
with open(chat_template_arg, "r") as filep:
template = json.load(filep)
try:
sep_style = SeparatorStyle[template["sep_style"]]
except KeyError:
raise ValueError(
f"Unknown separator style: {template['sep_style']}"
) from None
register_conv_template(
Conversation(
name=template["name"],
system_template=template["system"] + "\n{system_message}",
system_message=template.get("system_message", ""),
roles=(template["user"], template["assistant"]),
sep_style=sep_style,
sep=template.get("sep", "\n"),
stop_str=template["stop_str"],
),
override=True,
)
chat_template_name = template["name"]
else:
chat_template_name = chat_template_arg
def guess_chat_template_name_from_model_path(model_path):
global chat_template_name
chat_template_name = get_conv_template_by_model_path(model_path)
if chat_template_name is not None:
logger.info(
f"Infer the chat template name from the model path and obtain the result: {chat_template_name}."
)
def _validate_prompt(prompt: str):
"""Validate that the prompt is not empty or whitespace only."""
is_invalid = False
# Check for empty/whitespace string
if isinstance(prompt, str):
is_invalid = not prompt.strip()
# Check for various invalid list cases: [], [""], [" "], [[]]
elif isinstance(prompt, list):
is_invalid = not prompt or (
len(prompt) == 1
and (
(isinstance(prompt[0], str) and not prompt[0].strip())
or (isinstance(prompt[0], list) and not prompt[0])
)
)
if is_invalid:
raise HTTPException(
status_code=400,
detail="Input cannot be empty or contain only whitespace.",
)
return prompt
async def v1_files_create(
file: UploadFile, purpose: str, file_storage_path: str = None
):
try:
global storage_dir
if file_storage_path:
storage_dir = file_storage_path
# Read the file content
file_content = await file.read()
# Create an instance of RequestBody
request_body = FileRequest(file=file_content, purpose=purpose)
# Save the file to the sglang_oai_storage directory
os.makedirs(storage_dir, exist_ok=True)
file_id = f"backend_input_file-{uuid.uuid4()}"
filename = f"{file_id}.jsonl"
file_path = os.path.join(storage_dir, filename)
with open(file_path, "wb") as f:
f.write(request_body.file)
# add info to global file map
file_id_request[file_id] = FileMetadata(filename=file.filename, purpose=purpose)
file_id_storage[file_id] = file_path
# Return the response in the required format
response = FileResponse(
id=file_id,
bytes=len(request_body.file),
created_at=int(time.time()),
filename=file.filename,
purpose=request_body.purpose,
)
file_id_response[file_id] = response
return response
except ValidationError as e:
return {"error": "Invalid input", "details": e.errors()}
async def v1_delete_file(file_id: str):
# Retrieve the file job from the in-memory storage
file_response = file_id_response.get(file_id)
if file_response is None:
raise HTTPException(status_code=404, detail="File not found")
file_path = file_id_storage.get(file_id)
if file_path is None:
raise HTTPException(status_code=404, detail="File not found")
os.remove(file_path)
del file_id_response[file_id]
del file_id_storage[file_id]
return FileDeleteResponse(id=file_id, deleted=True)
async def v1_batches(tokenizer_manager, raw_request: Request):
try:
body = await raw_request.json()
batch_request = BatchRequest(**body)
batch_id = f"batch_{uuid.uuid4()}"
# Create an instance of BatchResponse
batch_response = BatchResponse(
id=batch_id,
endpoint=batch_request.endpoint,
input_file_id=batch_request.input_file_id,
completion_window=batch_request.completion_window,
created_at=int(time.time()),
metadata=batch_request.metadata,
)
batch_storage[batch_id] = batch_response
# Start processing the batch asynchronously
asyncio.create_task(process_batch(tokenizer_manager, batch_id, batch_request))
# Return the initial batch_response
return batch_response
except ValidationError as e:
return {"error": "Invalid input", "details": e.errors()}
except Exception as e:
return {"error": str(e)}
async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRequest):
try:
# Update the batch status to "in_progress"
batch_storage[batch_id].status = "in_progress"
batch_storage[batch_id].in_progress_at = int(time.time())
# Retrieve the input file content
input_file_request = file_id_request.get(batch_request.input_file_id)
if not input_file_request:
raise ValueError("Input file not found")
# Parse the JSONL file and process each request
input_file_path = file_id_storage.get(batch_request.input_file_id)
with open(input_file_path, "r", encoding="utf-8") as f:
lines = f.readlines()
total_requests = len(lines)
completed_requests = 0
failed_requests = 0
all_ret = []
end_point = batch_storage[batch_id].endpoint
file_request_list = []
all_requests = []
request_ids = []
for line_id, line in enumerate(lines):
request_data = json.loads(line)
file_request_list.append(request_data)
body = request_data["body"]
request_ids.append(f"{batch_id}-req_{line_id}")
# Although streaming is supported for standalone completions, it is not supported in
# batch mode (multiple completions in single request).
if body.get("stream", False):
raise ValueError("Streaming requests are not supported in batch mode")
if end_point == "/v1/chat/completions":
all_requests.append(ChatCompletionRequest(**body))
elif end_point == "/v1/completions":
all_requests.append(CompletionRequest(**body))
if end_point == "/v1/chat/completions":
adapted_request, request = v1_chat_generate_request(
all_requests, tokenizer_manager, request_ids=request_ids
)
elif end_point == "/v1/completions":
adapted_request, request = v1_generate_request(
all_requests, request_ids=request_ids
)
try:
created = int(time.time())
ret = await tokenizer_manager.generate_request(adapted_request).__anext__()
if not isinstance(ret, list):
ret = [ret]
if end_point == "/v1/chat/completions":
responses = v1_chat_generate_response(
request,
ret,
created,
to_file=True,
cache_report=tokenizer_manager.server_args.enable_cache_report,
tool_call_parser=tokenizer_manager.server_args.tool_call_parser,
)
else:
responses = v1_generate_response(
request,
ret,
tokenizer_manager,
created,
to_file=True,
cache_report=tokenizer_manager.server_args.enable_cache_report,
)
except Exception as e:
logger.error(f"error: {get_exception_traceback()}")
responses = []
error_json = {
"id": f"batch_req_{uuid.uuid4()}",
"custom_id": request_data.get("custom_id"),
"response": None,
"error": {"message": str(e)},
}
all_ret.append(error_json)
failed_requests += len(file_request_list)
for idx, response in enumerate(responses):
# the batch_req here can be changed to be named within a batch granularity
response_json = {
"id": f"batch_req_{uuid.uuid4()}",
"custom_id": file_request_list[idx].get("custom_id"),
"response": response,
"error": None,
}
all_ret.append(response_json)
completed_requests += 1
# Write results to a new file
output_file_id = f"backend_result_file-{uuid.uuid4()}"
global storage_dir
output_file_path = os.path.join(storage_dir, f"{output_file_id}.jsonl")
with open(output_file_path, "w", encoding="utf-8") as f:
for ret in all_ret:
f.write(json.dumps(ret) + "\n")
# Update batch response with output file information
retrieve_batch = batch_storage[batch_id]
retrieve_batch.output_file_id = output_file_id
file_id_storage[output_file_id] = output_file_path
file_id_response[output_file_id] = FileResponse(
id=output_file_id,
bytes=os.path.getsize(output_file_path),
created_at=int(time.time()),
filename=f"{output_file_id}.jsonl",
purpose="batch_result",
)
# Update batch status to "completed"
retrieve_batch.status = "completed"
retrieve_batch.completed_at = int(time.time())
retrieve_batch.request_counts = {
"total": total_requests,
"completed": completed_requests,
"failed": failed_requests,
}
except Exception as e:
logger.error(f"error: {e}")
# Update batch status to "failed"
retrieve_batch = batch_storage[batch_id]
retrieve_batch.status = "failed"
retrieve_batch.failed_at = int(time.time())
retrieve_batch.errors = {"message": str(e)}
async def v1_retrieve_batch(batch_id: str):
# Retrieve the batch job from the in-memory storage
batch_response = batch_storage.get(batch_id)
if batch_response is None:
raise HTTPException(status_code=404, detail="Batch not found")
return batch_response
async def v1_cancel_batch(tokenizer_manager, batch_id: str):
# Retrieve the batch job from the in-memory storage
batch_response = batch_storage.get(batch_id)
if batch_response is None:
raise HTTPException(status_code=404, detail="Batch not found")
# Only do cancal when status is "validating" or "in_progress"
if batch_response.status in ["validating", "in_progress"]:
# Start cancelling the batch asynchronously
asyncio.create_task(
cancel_batch(
tokenizer_manager=tokenizer_manager,
batch_id=batch_id,
input_file_id=batch_response.input_file_id,
)
)
# Update batch status to "cancelling"
batch_response.status = "cancelling"
return batch_response
else:
raise HTTPException(
status_code=500,
detail=f"Current status is {batch_response.status}, no need to cancel",
)
async def cancel_batch(tokenizer_manager, batch_id: str, input_file_id: str):
try:
# Update the batch status to "cancelling"
batch_storage[batch_id].status = "cancelling"
# Retrieve the input file content
input_file_request = file_id_request.get(input_file_id)
if not input_file_request:
raise ValueError("Input file not found")
# Parse the JSONL file and process each request
input_file_path = file_id_storage.get(input_file_id)
with open(input_file_path, "r", encoding="utf-8") as f:
lines = f.readlines()
# Cancel requests by request_ids
for line_id in range(len(lines)):
rid = f"{batch_id}-req_{line_id}"
tokenizer_manager.abort_request(rid=rid)
retrieve_batch = batch_storage[batch_id]
retrieve_batch.status = "cancelled"
except Exception as e:
logger.error("error in SGLang:", e)
# Update batch status to "failed"
retrieve_batch = batch_storage[batch_id]
retrieve_batch.status = "failed"
retrieve_batch.failed_at = int(time.time())
retrieve_batch.errors = {"message": str(e)}
async def v1_retrieve_file(file_id: str):
# Retrieve the batch job from the in-memory storage
file_response = file_id_response.get(file_id)
if file_response is None:
raise HTTPException(status_code=404, detail="File not found")
return file_response
async def v1_retrieve_file_content(file_id: str):
file_pth = file_id_storage.get(file_id)
if not file_pth or not os.path.exists(file_pth):
raise HTTPException(status_code=404, detail="File not found")
def iter_file():
with open(file_pth, mode="rb") as file_like:
yield from file_like
return StreamingResponse(iter_file(), media_type="application/octet-stream")
def v1_generate_request(
all_requests: List[CompletionRequest], request_ids: List[str] = None
):
if len(all_requests) > 1:
first_prompt_type = type(all_requests[0].prompt)
for request in all_requests:
assert (
type(request.prompt) is first_prompt_type
), "All prompts must be of the same type in file input settings"
if request.n > 1:
raise ValueError(
"Parallel sampling is not supported for completions from files"
)
prompts = []
sampling_params_list = []
return_logprobs = []
logprob_start_lens = []
top_logprobs_nums = []
lora_paths = []
return_hidden_states = []
for request in all_requests:
# NOTE: with openai API, the prompt's logprobs are always not computed
if request.echo and request.logprobs:
logger.warning(
"Echo is not compatible with logprobs. "
"To compute logprobs of input prompt, please use the native /generate API."
)
prompt = request.prompt
if is_completion_template_defined():
prompt = generate_completion_prompt_from_request(request)
prompts.append(prompt)
lora_paths.append(request.lora_path)
if request.echo and request.logprobs:
current_logprob_start_len = 0
else:
current_logprob_start_len = -1
sampling_params_list.append(
{
"temperature": request.temperature,
"max_new_tokens": request.max_tokens,
"min_new_tokens": request.min_tokens,
"stop": request.stop,
"stop_token_ids": request.stop_token_ids,
"top_p": request.top_p,
"top_k": request.top_k,
"min_p": request.min_p,
"presence_penalty": request.presence_penalty,
"frequency_penalty": request.frequency_penalty,
"repetition_penalty": request.repetition_penalty,
"regex": request.regex,
"json_schema": request.json_schema,
"ebnf": request.ebnf,
"n": request.n,
"no_stop_trim": request.no_stop_trim,
"ignore_eos": request.ignore_eos,
"skip_special_tokens": request.skip_special_tokens,
"logit_bias": request.logit_bias,
}
)
return_logprobs.append(request.logprobs is not None)
logprob_start_lens.append(current_logprob_start_len)
top_logprobs_nums.append(
request.logprobs if request.logprobs is not None else 0
)
return_hidden_states.append(request.return_hidden_states)
if len(all_requests) == 1:
if isinstance(prompts[0], str) or isinstance(prompts[0][0], str):
prompt_kwargs = {"text": prompts[0]}
else:
prompt_kwargs = {"input_ids": prompts[0]}
sampling_params_list = sampling_params_list[0]
return_logprobs = return_logprobs[0]
logprob_start_lens = logprob_start_lens[0]
top_logprobs_nums = top_logprobs_nums[0]
lora_paths = lora_paths[0]
return_hidden_states = return_hidden_states[0]
else:
if isinstance(prompts[0], str) or isinstance(prompts[0][0], str):
prompt_kwargs = {"text": prompts}
else:
prompt_kwargs = {"input_ids": prompts}
adapted_request = GenerateReqInput(
**prompt_kwargs,
sampling_params=sampling_params_list,
return_logprob=return_logprobs,
top_logprobs_num=top_logprobs_nums,
logprob_start_len=logprob_start_lens,
return_text_in_logprobs=True,
stream=all_requests[0].stream,
rid=request_ids,
lora_path=lora_paths,
return_hidden_states=return_hidden_states,
bootstrap_host=all_requests[0].bootstrap_host,
bootstrap_port=all_requests[0].bootstrap_port,
bootstrap_room=all_requests[0].bootstrap_room,
)
return adapted_request, all_requests if len(all_requests) > 1 else all_requests[0]
def v1_generate_response(
request, ret, tokenizer_manager, created, to_file=False, cache_report=False
):
choices = []
echo = False
if (not isinstance(request, list)) and request.echo:
# TODO: handle the case prompt is token ids
if isinstance(request.prompt, list) and isinstance(request.prompt[0], str):
# for the case of multiple str prompts
prompts = request.prompt
elif isinstance(request.prompt, list) and isinstance(request.prompt[0], list):
# for the case of multiple token ids prompts
prompts = [
tokenizer_manager.tokenizer.decode(prompt, skip_special_tokens=True)
for prompt in request.prompt
]
elif isinstance(request.prompt, list) and isinstance(request.prompt[0], int):
# for the case of single token ids prompt
prompts = [
tokenizer_manager.tokenizer.decode(
request.prompt, skip_special_tokens=True
)
]
else:
# for the case of single str prompt
prompts = [request.prompt]
echo = True
for idx, ret_item in enumerate(ret):
text = ret_item["text"]
if isinstance(request, list) and request[idx].echo:
echo = True
text = request[idx].prompt + text
if echo and not isinstance(request, list):
prompt_index = idx // request.n
text = prompts[prompt_index] + text
logprobs = False
if isinstance(request, list) and request[idx].logprobs is not None:
logprobs = True
elif (not isinstance(request, list)) and request.logprobs is not None:
logprobs = True
if logprobs:
if echo:
input_token_logprobs = ret_item["meta_info"]["input_token_logprobs"]
input_top_logprobs = ret_item["meta_info"]["input_top_logprobs"]
else:
input_token_logprobs = None
input_top_logprobs = None
logprobs = to_openai_style_logprobs(
input_token_logprobs=input_token_logprobs,
input_top_logprobs=input_top_logprobs,
output_token_logprobs=ret_item["meta_info"]["output_token_logprobs"],
output_top_logprobs=ret_item["meta_info"]["output_top_logprobs"],
)
else:
logprobs = None
hidden_states = None
if isinstance(request, list) and request[idx].return_hidden_states:
hidden_states = ret_item["meta_info"].get("hidden_states", None)
elif (not isinstance(request, list)) and request.return_hidden_states:
hidden_states = ret_item["meta_info"].get("hidden_states", None)
if hidden_states is not None:
hidden_states = (
hidden_states[-1] if hidden_states and len(hidden_states) > 1 else []
)
finish_reason = ret_item["meta_info"]["finish_reason"]
if to_file:
# to make the choice data json serializable
choice_data = {
"index": 0,
"text": text,
"logprobs": logprobs,
"finish_reason": finish_reason["type"] if finish_reason else None,
"matched_stop": (
finish_reason["matched"]
if finish_reason and "matched" in finish_reason
else None
),
}
if hidden_states is not None:
choice_data["hidden_states"] = hidden_states
else:
choice_data = CompletionResponseChoice(
index=idx,
text=text,
logprobs=logprobs,
finish_reason=finish_reason["type"] if finish_reason else None,
matched_stop=(
finish_reason["matched"]
if finish_reason and "matched" in finish_reason
else None
),
hidden_states=hidden_states,
)
choices.append(choice_data)
if to_file:
responses = []
for i, choice in enumerate(choices):
response = {
"status_code": 200,
"request_id": ret[i]["meta_info"]["id"],
"body": {
# remain the same but if needed we can change that
"id": ret[i]["meta_info"]["id"],
"object": "text_completion",
"created": created,
"model": request[i].model,
"choices": choice,
"usage": {
"prompt_tokens": ret[i]["meta_info"]["prompt_tokens"],
"completion_tokens": ret[i]["meta_info"]["completion_tokens"],
"total_tokens": ret[i]["meta_info"]["prompt_tokens"]
+ ret[i]["meta_info"]["completion_tokens"],
},
"system_fingerprint": None,
},
}
responses.append(response)
return responses
else:
prompt_tokens = sum(
ret[i]["meta_info"]["prompt_tokens"] for i in range(0, len(ret), request.n)
)
completion_tokens = sum(item["meta_info"]["completion_tokens"] for item in ret)
cached_tokens = sum(item["meta_info"].get("cached_tokens", 0) for item in ret)
response = CompletionResponse(
id=ret[0]["meta_info"]["id"],
model=request.model,
created=created,
choices=choices,
usage=UsageInfo(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
prompt_tokens_details=(
{"cached_tokens": cached_tokens} if cache_report else None
),
),
)
return response
async def v1_completions(tokenizer_manager, raw_request: Request):
try:
request_json = await raw_request.json()
except Exception as e:
return create_error_response("Invalid request body, error: ", str(e))
all_requests = [CompletionRequest(**request_json)]
created = int(time.time())
adapted_request, request = v1_generate_request(all_requests)
if adapted_request.stream:
async def generate_stream_resp():
stream_buffers = {}
n_prev_tokens = {}
prompt_tokens = {}
completion_tokens = {}
cached_tokens = {}
hidden_states = {}
try:
async for content in tokenizer_manager.generate_request(
adapted_request, raw_request
):
index = content.get("index", 0)
stream_buffer = stream_buffers.get(index, "")
n_prev_token = n_prev_tokens.get(index, 0)
text = content["text"]
prompt_tokens[index] = content["meta_info"]["prompt_tokens"]
completion_tokens[index] = content["meta_info"]["completion_tokens"]
cached_tokens[index] = content["meta_info"].get("cached_tokens", 0)
hidden_states[index] = content["meta_info"].get(
"hidden_states", None
) or hidden_states.get(index)
if not stream_buffer: # The first chunk
if request.echo:
if isinstance(request.prompt, str):
# for the case of single str prompts
prompts = request.prompt
elif isinstance(request.prompt, list):
if isinstance(request.prompt[0], str):
# for the case of multiple str prompts
prompts = request.prompt[index // request.n]
elif isinstance(request.prompt[0], int):
# for the case of single token ids prompt
prompts = tokenizer_manager.tokenizer.decode(
request.prompt, skip_special_tokens=True
)
elif isinstance(request.prompt[0], list) and isinstance(
request.prompt[0][0], int
):
# for the case of multiple token ids prompts
prompts = tokenizer_manager.tokenizer.decode(
request.prompt[index // request.n],
skip_special_tokens=True,
)
# Prepend prompt in response text.
text = prompts + text
if request.logprobs is not None:
# The first chunk and echo is enabled.
if not stream_buffer and request.echo:
input_token_logprobs = content["meta_info"][
"input_token_logprobs"
]
input_top_logprobs = content["meta_info"][
"input_top_logprobs"
]
else:
input_token_logprobs = None
input_top_logprobs = None
logprobs = to_openai_style_logprobs(
input_token_logprobs=input_token_logprobs,
input_top_logprobs=input_top_logprobs,
output_token_logprobs=content["meta_info"][
"output_token_logprobs"
][n_prev_token:],
output_top_logprobs=content["meta_info"][
"output_top_logprobs"
][n_prev_token:],
)
n_prev_token = len(
content["meta_info"]["output_token_logprobs"]
)
else:
logprobs = None
delta = text[len(stream_buffer) :]
stream_buffer = stream_buffer + delta
finish_reason = content["meta_info"]["finish_reason"]
choice_data = CompletionResponseStreamChoice(
index=index,
text=delta,
logprobs=logprobs,
finish_reason=finish_reason["type"] if finish_reason else None,
matched_stop=(
finish_reason["matched"]
if finish_reason and "matched" in finish_reason
else None
),
)
chunk = CompletionStreamResponse(
id=content["meta_info"]["id"],
created=created,
object="text_completion",
choices=[choice_data],
model=request.model,
)
stream_buffers[index] = stream_buffer
n_prev_tokens[index] = n_prev_token
yield f"data: {chunk.model_dump_json()}\n\n"
if request.return_hidden_states and hidden_states:
for index, choice_hidden_states in hidden_states.items():
last_token_hidden_states = (
choice_hidden_states[-1]
if choice_hidden_states and len(choice_hidden_states) > 1
else []
)
hidden_states_chunk = CompletionStreamResponse(
id=content["meta_info"]["id"],
created=created,
choices=[
CompletionResponseStreamChoice(
text="",
index=index,
hidden_states=last_token_hidden_states,
finish_reason=None,
)
],
model=request.model,
)
yield f"data: {hidden_states_chunk.model_dump_json()}\n\n"
if request.stream_options and request.stream_options.include_usage:
total_prompt_tokens = sum(
tokens
for i, tokens in prompt_tokens.items()
if i % request.n == 0
)
total_completion_tokens = sum(
tokens for tokens in completion_tokens.values()
)
cache_report = tokenizer_manager.server_args.enable_cache_report
if cache_report:
cached_tokens_sum = sum(
tokens for tokens in cached_tokens.values()
)
prompt_tokens_details = {"cached_tokens": cached_tokens_sum}
else:
prompt_tokens_details = None
usage = UsageInfo(
prompt_tokens=total_prompt_tokens,
completion_tokens=total_completion_tokens,
total_tokens=total_prompt_tokens + total_completion_tokens,
prompt_tokens_details=prompt_tokens_details,
)
final_usage_chunk = CompletionStreamResponse(
id=content["meta_info"]["id"],
created=created,
choices=[],
model=request.model,
usage=usage,
)
final_usage_data = final_usage_chunk.model_dump_json(
exclude_none=True
)
yield f"data: {final_usage_data}\n\n"
except ValueError as e:
error = create_streaming_error_response(str(e))
yield f"data: {error}\n\n"
yield "data: [DONE]\n\n"
return StreamingResponse(
generate_stream_resp(),
media_type="text/event-stream",
background=tokenizer_manager.create_abort_task(adapted_request),
)
# Non-streaming response.
try:
ret = await tokenizer_manager.generate_request(
adapted_request, raw_request
).__anext__()
except ValueError as e:
return create_error_response(str(e))
if not isinstance(ret, list):
ret = [ret]
response = v1_generate_response(
request,
ret,
tokenizer_manager,
created,
cache_report=tokenizer_manager.server_args.enable_cache_report,
)
return response
def _get_enable_thinking_from_request(request_obj):
"""Extracts the 'enable_thinking' flag from request chat_template_kwargs.
Args:
request_obj: The request object (or an item from a list of requests).
Returns:
The boolean value of 'enable_thinking' if found and not True, otherwise True.
"""
if (
hasattr(request_obj, "chat_template_kwargs")
and request_obj.chat_template_kwargs
and request_obj.chat_template_kwargs.get("enable_thinking") is not None
):
return request_obj.chat_template_kwargs.get("enable_thinking")
return True
def v1_chat_generate_request(
all_requests: List[ChatCompletionRequest],
tokenizer_manager,
request_ids: List[str] = None,
):
input_ids = []
prompts = []
sampling_params_list = []
image_data_list = []
audio_data_list = []
return_logprobs = []
logprob_start_lens = []
top_logprobs_nums = []
modalities_list = []
lora_paths = []
return_hidden_states = []
# NOTE: with openai API, the prompt's logprobs are always not computed
is_multimodal = tokenizer_manager.model_config.is_multimodal
for request in all_requests:
# Prep the data needed for the underlying GenerateReqInput:
# - prompt: The full prompt string.
# - stop: Custom stop tokens.
# - image_data: None or a list of image strings (URLs or base64 strings).
# - audio_data: None or a list of audio strings (URLs).
# None skips any image processing in GenerateReqInput.
tool_call_constraint = None
prompt = ""
prompt_ids = []
if not isinstance(request.messages, str):
# Apply chat template and its stop strings.
tools = None
if request.tools and request.tool_choice != "none":
request.skip_special_tokens = False
if not isinstance(request.tool_choice, str):
tools = [
item.function.model_dump()
for item in request.tools
if item.function.name == request.tool_choice.function.name
]
else:
tools = [item.function.model_dump() for item in request.tools]
tool_call_parser = tokenizer_manager.server_args.tool_call_parser
parser = FunctionCallParser(request.tools, tool_call_parser)
tool_call_constraint = parser.get_structure_constraint(
request.tool_choice
)
if chat_template_name is None:
openai_compatible_messages = []
image_data = []
audio_data = []
modalities = []
# Detect template content format by analyzing the jinja template (cached globally)
global _cached_chat_template, _cached_template_format
current_template = tokenizer_manager.tokenizer.chat_template
if current_template != _cached_chat_template:
# Template changed or first time - analyze it
_cached_chat_template = current_template
_cached_template_format = detect_template_content_format(
current_template
)
logger.info(
f"Detected chat template content format: {_cached_template_format}"
)
template_content_format = _cached_template_format
for message in request.messages:
if message.content is None:
message.content = ""
msg_dict = message.model_dump()
# Process content based on detected template format
processed_msg = process_content_for_template_format(
msg_dict,
template_content_format,
image_data,
audio_data,
modalities,
)
openai_compatible_messages.append(processed_msg)
# Handle assistant prefix for continue_final_message
if (
openai_compatible_messages
and openai_compatible_messages[-1]["role"] == "assistant"
):
if request.continue_final_message:
# Remove the final assistant message so its content can be continued.
assistant_prefix = openai_compatible_messages[-1]["content"]
openai_compatible_messages = openai_compatible_messages[:-1]
else:
assistant_prefix = None
else:
assistant_prefix = None
try:
prompt_ids = tokenizer_manager.tokenizer.apply_chat_template(
openai_compatible_messages,
tokenize=True,
add_generation_prompt=True,
tools=tools,
**(
request.chat_template_kwargs
if request.chat_template_kwargs
else {}
),
)
except:
# This except branch will be triggered when the chosen model
# has a different tools input format that is not compatible
# with openAI's apply_chat_template tool_call format, like Mistral.
tools = [t if "function" in t else {"function": t} for t in tools]
prompt_ids = tokenizer_manager.tokenizer.apply_chat_template(
openai_compatible_messages,
tokenize=True,
add_generation_prompt=True,
tools=tools,
**(
request.chat_template_kwargs
if request.chat_template_kwargs
else {}
),
)
if assistant_prefix:
encoded = tokenizer_manager.tokenizer.encode(assistant_prefix)
if (
encoded
and encoded[0] == tokenizer_manager.tokenizer.bos_token_id
):
encoded = encoded[1:]
prompt_ids += encoded
if is_multimodal:
prompt = tokenizer_manager.tokenizer.decode(prompt_ids)
stop = request.stop
image_data = image_data if image_data else None
audio_data = audio_data if audio_data else None
modalities = modalities if modalities else []
else:
conv = generate_chat_conv(request, chat_template_name)
# If we should continue the final assistant message, adjust the conversation.
if (
request.continue_final_message
and request.messages
and request.messages[-1].role == "assistant"
):
# Remove the auto-added blank assistant turn, if present.
if conv.messages and conv.messages[-1][1] is None:
conv.messages.pop()
# Rebuild the prompt from the conversation.
prompt = conv.get_prompt()
# Strip any trailing stop tokens or separators that indicate end-of-assistant.
if isinstance(conv.stop_str, list):
for stop_token in conv.stop_str:
if prompt.endswith(stop_token):
prompt = prompt[: -len(stop_token)]
elif isinstance(conv.stop_str, str) and prompt.endswith(
conv.stop_str
):
prompt = prompt[: -len(conv.stop_str)]
if conv.sep and prompt.endswith(conv.sep):
prompt = prompt[: -len(conv.sep)]
if getattr(conv, "sep2", None) and prompt.endswith(conv.sep2):
prompt = prompt[: -len(conv.sep2)]
else:
prompt = conv.get_prompt()
image_data = conv.image_data
audio_data = conv.audio_data
modalities = conv.modalities
stop = conv.stop_str or [] if not request.ignore_eos else []
if request.stop:
if isinstance(request.stop, str):
stop.append(request.stop)
else:
stop.extend(request.stop)
if not is_multimodal:
prompt_ids = tokenizer_manager.tokenizer.encode(prompt)
else:
# Use the raw prompt and stop strings if the messages is already a string.
prompt_ids = request.messages
stop = request.stop
image_data = None
audio_data = None
modalities = []
prompt = request.messages
input_ids.append(prompt_ids)
return_logprobs.append(request.logprobs)
logprob_start_lens.append(-1)
top_logprobs_nums.append(request.top_logprobs or 0)
lora_paths.append(request.lora_path)
prompts.append(prompt)
sampling_params = {
"temperature": request.temperature,
"max_new_tokens": request.max_tokens or request.max_completion_tokens,
"min_new_tokens": request.min_tokens,
"stop": stop,
"stop_token_ids": request.stop_token_ids,
"top_p": request.top_p,
"top_k": request.top_k,
"min_p": request.min_p,
"presence_penalty": request.presence_penalty,
"frequency_penalty": request.frequency_penalty,
"repetition_penalty": request.repetition_penalty,
"regex": request.regex,
"ebnf": request.ebnf,
"n": request.n,
"no_stop_trim": request.no_stop_trim,
"ignore_eos": request.ignore_eos,
"skip_special_tokens": request.skip_special_tokens,
"logit_bias": request.logit_bias,
}
if request.response_format and request.response_format.type == "json_schema":
sampling_params["json_schema"] = convert_json_schema_to_str(
request.response_format.json_schema.schema_
)
elif request.response_format and request.response_format.type == "json_object":
sampling_params["json_schema"] = '{"type": "object"}'
elif (
request.response_format and request.response_format.type == "structural_tag"
):
sampling_params["structural_tag"] = convert_json_schema_to_str(
request.response_format.model_dump(by_alias=True)
)
# Check if there are already existing output constraints
has_existing_constraints = (
sampling_params.get("regex")
or sampling_params.get("ebnf")
or sampling_params.get("structural_tag")
or sampling_params.get("json_schema")
)
if tool_call_constraint and has_existing_constraints:
logger.warning("Constrained decoding is not compatible with tool calls.")
elif tool_call_constraint:
constraint_type, constraint_value = tool_call_constraint
if constraint_type == "structural_tag":
sampling_params[constraint_type] = convert_json_schema_to_str(
constraint_value.model_dump(by_alias=True)
)
else:
sampling_params[constraint_type] = constraint_value
sampling_params_list.append(sampling_params)
image_data_list.append(image_data)
audio_data_list.append(audio_data)
modalities_list.append(modalities)
return_hidden_states.append(request.return_hidden_states)
if len(all_requests) == 1:
if is_multimodal:
# processor will need text input
prompt_kwargs = {"text": prompts[0]}
else:
if isinstance(input_ids[0], str):
prompt_kwargs = {"text": input_ids[0]}
else:
prompt_kwargs = {"input_ids": input_ids[0]}
sampling_params_list = sampling_params_list[0]
image_data_list = image_data_list[0]
audio_data_list = audio_data_list[0]
return_logprobs = return_logprobs[0]
logprob_start_lens = logprob_start_lens[0]
top_logprobs_nums = top_logprobs_nums[0]
modalities_list = modalities_list[0]
lora_paths = lora_paths[0]
request_ids = request_ids[0]
return_hidden_states = return_hidden_states[0]
else:
if tokenizer_manager.model_config.is_multimodal:
# processor will need text input
prompt_kwargs = {"text": prompts}
else:
if isinstance(input_ids[0], str):
prompt_kwargs = {"text": input_ids}
else:
prompt_kwargs = {"input_ids": input_ids}
adapted_request = GenerateReqInput(
**prompt_kwargs,
image_data=image_data_list,
audio_data=audio_data_list,
sampling_params=sampling_params_list,
return_logprob=return_logprobs,
logprob_start_len=logprob_start_lens,
top_logprobs_num=top_logprobs_nums,
stream=all_requests[0].stream,
return_text_in_logprobs=True,
rid=request_ids,
modalities=modalities_list,
lora_path=lora_paths,
bootstrap_host=all_requests[0].bootstrap_host,
bootstrap_port=all_requests[0].bootstrap_port,
bootstrap_room=all_requests[0].bootstrap_room,
return_hidden_states=return_hidden_states,
)
return adapted_request, all_requests if len(all_requests) > 1 else all_requests[0]
def v1_chat_generate_response(
request,
ret,
created,
to_file=False,
cache_report=False,
tool_call_parser=None,
reasoning_parser=None,
):
choices = []
for idx, ret_item in enumerate(ret):
logprobs = False
if isinstance(request, list) and request[idx].logprobs:
logprobs = True
elif (not isinstance(request, list)) and request.logprobs:
logprobs = True
if logprobs:
logprobs = to_openai_style_logprobs(
output_token_logprobs=ret_item["meta_info"]["output_token_logprobs"],
output_top_logprobs=ret_item["meta_info"].get(
"output_top_logprobs", None
),
)
token_logprobs = []
for token_idx, (token, logprob) in enumerate(
zip(logprobs.tokens, logprobs.token_logprobs)
):
token_bytes = list(token.encode("utf-8"))
top_logprobs = []
if logprobs.top_logprobs:
for top_token, top_logprob in logprobs.top_logprobs[
token_idx
].items():
top_token_bytes = list(top_token.encode("utf-8"))
top_logprobs.append(
TopLogprob(
token=top_token,
bytes=top_token_bytes,
logprob=top_logprob,
)
)
token_logprobs.append(
ChatCompletionTokenLogprob(
token=token,
bytes=token_bytes,
logprob=logprob,
top_logprobs=top_logprobs,
)
)
choice_logprobs = ChoiceLogprobs(content=token_logprobs)
else:
choice_logprobs = None
if isinstance(request, list) and request[idx].return_hidden_states:
include_hidden_states = True
elif not isinstance(request, list) and request.return_hidden_states:
include_hidden_states = True
else:
include_hidden_states = False
if include_hidden_states and ret_item["meta_info"].get("hidden_states", None):
hidden_states = ret_item["meta_info"]["hidden_states"]
hidden_states = (
hidden_states[-1] if hidden_states and len(hidden_states) > 1 else []
)
else:
hidden_states = None
finish_reason = ret_item["meta_info"]["finish_reason"]
tool_calls = None
text = ret_item["text"]
if isinstance(request, list):
tool_choice = request[idx].tool_choice
tools = request[idx].tools
separate_reasoning = request[idx].separate_reasoning
enable_thinking = _get_enable_thinking_from_request(request[idx])
else:
tool_choice = request.tool_choice
tools = request.tools
separate_reasoning = request.separate_reasoning
enable_thinking = _get_enable_thinking_from_request(request)
reasoning_text = None
if reasoning_parser and separate_reasoning and enable_thinking:
try:
parser = ReasoningParser(
model_type=reasoning_parser, stream_reasoning=False
)
reasoning_text, text = parser.parse_non_stream(text)
except Exception as e:
logger.error(f"Exception: {e}")
return create_error_response(
HTTPStatus.BAD_REQUEST,
"Failed to parse reasoning related info to json format!",
)
if tool_choice != "none" and tools:
parser = FunctionCallParser(tools, tool_call_parser)
if parser.has_tool_call(text):
if finish_reason["type"] == "stop":
finish_reason["type"] = "tool_calls"
finish_reason["matched"] = None
try:
text, call_info_list = parser.parse_non_stream(text)
tool_calls = [
ToolCall(
id=f"call_{base64.urlsafe_b64encode(uuid.uuid4().bytes).rstrip(b'=').decode()}",
function=FunctionResponse(
name=call_info.name, arguments=call_info.parameters
),
)
for call_info in call_info_list
]
except Exception as e:
logger.error(f"Exception: {e}")
return create_error_response(
HTTPStatus.BAD_REQUEST,
"Failed to parse fc related info to json format!",
)
if to_file:
# to make the choice data json serializable
choice_data = {
"index": 0,
"message": {
"role": "assistant",
"content": text if text else None,
"tool_calls": tool_calls,
"reasoning_content": reasoning_text if reasoning_text else None,
},
"logprobs": choice_logprobs.model_dump() if choice_logprobs else None,
"finish_reason": finish_reason["type"] if finish_reason else None,
"matched_stop": (
finish_reason["matched"]
if finish_reason and "matched" in finish_reason
else None
),
}
if hidden_states is not None:
choice_data["hidden_states"] = hidden_states
else:
choice_data = ChatCompletionResponseChoice(
index=idx,
message=ChatMessage(
role="assistant",
content=text if text else None,
tool_calls=tool_calls,
reasoning_content=reasoning_text if reasoning_text else None,
),
logprobs=choice_logprobs,
finish_reason=finish_reason["type"] if finish_reason else None,
matched_stop=(
finish_reason["matched"]
if finish_reason and "matched" in finish_reason
else None
),
hidden_states=hidden_states,
)
choices.append(choice_data)
if to_file:
responses = []
for i, choice in enumerate(choices):
response = {
"status_code": 200,
"request_id": ret[i]["meta_info"]["id"],
"body": {
# remain the same but if needed we can change that
"id": ret[i]["meta_info"]["id"],
"object": "chat.completion",
"created": created,
"model": (
request[i].model if isinstance(request, list) else request.model
),
"choices": choice,
"usage": {
"prompt_tokens": ret[i]["meta_info"]["prompt_tokens"],
"completion_tokens": ret[i]["meta_info"]["completion_tokens"],
"total_tokens": ret[i]["meta_info"]["prompt_tokens"]
+ ret[i]["meta_info"]["completion_tokens"],
},
"system_fingerprint": None,
},
}
responses.append(response)
return responses
else:
prompt_tokens = sum(
ret[i]["meta_info"]["prompt_tokens"] for i in range(0, len(ret), request.n)
)
completion_tokens = sum(item["meta_info"]["completion_tokens"] for item in ret)
cached_tokens = sum(item["meta_info"].get("cached_tokens", 0) for item in ret)
response = ChatCompletionResponse(
id=ret[0]["meta_info"]["id"],
created=created,
model=request.model,
choices=choices,
usage=UsageInfo(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
prompt_tokens_details=(
{"cached_tokens": cached_tokens} if cache_report else None
),
),
)
return response
async def v1_chat_completions(
tokenizer_manager, raw_request: Request, cache_report=False
):
try:
request_json = await raw_request.json()
except Exception as e:
return create_error_response("Invalid request body, error: ", str(e))
all_requests = [ChatCompletionRequest(**request_json)]
created = int(time.time())
adapted_request, request = v1_chat_generate_request(
all_requests, tokenizer_manager, request_ids=[all_requests[0].rid]
)
if adapted_request.stream:
parser_dict = {}
reasoning_parser_dict = {}
async def generate_stream_resp():
tool_index_previous = -1
is_firsts = {}
stream_buffers = {}
n_prev_tokens = {}
prompt_tokens = {}
completion_tokens = {}
cached_tokens = {}
hidden_states = {}
try:
async for content in tokenizer_manager.generate_request(
adapted_request, raw_request
):
index = content.get("index", 0)
text = content["text"]
hidden_states[index] = content["meta_info"].get(
"hidden_states", None
) or hidden_states.get(index)
is_first = is_firsts.get(index, True)
stream_buffer = stream_buffers.get(index, "")
n_prev_token = n_prev_tokens.get(index, 0)
prompt_tokens[index] = content["meta_info"]["prompt_tokens"]
completion_tokens[index] = content["meta_info"]["completion_tokens"]
cached_tokens[index] = content["meta_info"].get("cached_tokens", 0)
if request.logprobs:
logprobs = to_openai_style_logprobs(
output_token_logprobs=content["meta_info"][
"output_token_logprobs"
][n_prev_token:],
output_top_logprobs=content["meta_info"].get(
"output_top_logprobs", []
)[n_prev_token:],
)
n_prev_token = len(
content["meta_info"]["output_token_logprobs"]
)
token_logprobs = []
for token, logprob in zip(
logprobs.tokens, logprobs.token_logprobs
):
token_bytes = list(token.encode("utf-8"))
top_logprobs = []
if logprobs.top_logprobs:
for top_token, top_logprob in logprobs.top_logprobs[
0
].items():
top_token_bytes = list(top_token.encode("utf-8"))
top_logprobs.append(
TopLogprob(
token=top_token,
bytes=top_token_bytes,
logprob=top_logprob,
)
)
token_logprobs.append(
ChatCompletionTokenLogprob(
token=token,
bytes=token_bytes,
logprob=logprob,
top_logprobs=top_logprobs,
)
)
choice_logprobs = ChoiceLogprobs(content=token_logprobs)
else:
choice_logprobs = None
finish_reason = content["meta_info"]["finish_reason"]
finish_reason_type = (
finish_reason["type"] if finish_reason else None
)
if is_first:
# First chunk with role
is_first = False
delta = DeltaMessage(role="assistant")
choice_data = ChatCompletionResponseStreamChoice(
index=index,
delta=delta,
finish_reason=finish_reason_type,
matched_stop=(
finish_reason["matched"]
if finish_reason and "matched" in finish_reason
else None
),
logprobs=choice_logprobs,
)
chunk = ChatCompletionStreamResponse(
id=content["meta_info"]["id"],
created=created,
choices=[choice_data],
model=request.model,
)
yield f"data: {chunk.model_dump_json()}\n\n"
text = content["text"]
delta = text[len(stream_buffer) :]
new_stream_buffer = stream_buffer + delta
enable_thinking = _get_enable_thinking_from_request(request)
if (
tokenizer_manager.server_args.reasoning_parser
and request.separate_reasoning
and enable_thinking
):
if index not in reasoning_parser_dict:
reasoning_parser_dict[index] = ReasoningParser(
tokenizer_manager.server_args.reasoning_parser,
request.stream_reasoning,
)
reasoning_parser = reasoning_parser_dict[index]
reasoning_text, delta = reasoning_parser.parse_stream_chunk(
delta
)
if reasoning_text:
choice_data = ChatCompletionResponseStreamChoice(
index=index,
delta=DeltaMessage(
reasoning_content=(
reasoning_text if reasoning_text else None
)
),
finish_reason=finish_reason_type,
)
chunk = ChatCompletionStreamResponse(
id=content["meta_info"]["id"],
created=created,
choices=[choice_data],
model=request.model,
)
yield f"data: {chunk.model_dump_json()}\n\n"
if (delta and len(delta) == 0) or not delta:
stream_buffers[index] = new_stream_buffer
is_firsts[index] = is_first
n_prev_tokens[index] = n_prev_token
continue
if request.tool_choice != "none" and request.tools:
if index not in parser_dict:
parser_dict[index] = FunctionCallParser(
tools=request.tools,
tool_call_parser=tokenizer_manager.server_args.tool_call_parser,
)
parser = parser_dict[index]
# parse_increment => returns (normal_text, calls)
normal_text, calls = parser.parse_stream_chunk(delta)
# 1) if there's normal_text, output it as normal content
if normal_text:
choice_data = ChatCompletionResponseStreamChoice(
index=index,
delta=DeltaMessage(
content=normal_text if normal_text else None
),
finish_reason=finish_reason_type,
)
chunk = ChatCompletionStreamResponse(
id=content["meta_info"]["id"],
created=created,
choices=[choice_data],
model=request.model,
)
yield f"data: {chunk.model_dump_json()}\n\n"
# 2) if we found calls, we output them as separate chunk(s)
for call_item in calls:
tool_index_current = call_item.tool_index
# transform call_item -> FunctionResponse + ToolCall
if finish_reason_type == "stop":
latest_delta_len = 0
if isinstance(call_item.parameters, str):
latest_delta_len = len(call_item.parameters)
expected_call = json.dumps(
parser.detector.prev_tool_call_arr[index].get(
"arguments", {}
),
ensure_ascii=False,
)
actual_call = parser.detector.streamed_args_for_tool[
index
]
if latest_delta_len > 0:
actual_call = actual_call[:-latest_delta_len]
remaining_call = expected_call.replace(
actual_call, "", 1
)
call_item.parameters = remaining_call
finish_reason_type = "tool_calls"
tool_call = ToolCall(
id=(
f"call_{base64.urlsafe_b64encode(uuid.uuid4().bytes).rstrip(b'=').decode()}"
if tool_index_previous != tool_index_current
else None
),
index=call_item.tool_index,
function=FunctionResponse(
name=call_item.name,
arguments=call_item.parameters,
),
)
tool_index_previous = tool_index_current
choice_data = ChatCompletionResponseStreamChoice(
index=index,
delta=DeltaMessage(tool_calls=[tool_call]),
finish_reason=(
None
if request.stream_options
and request.stream_options.include_usage
else finish_reason_type
), # additional chunk will be return
)
chunk = ChatCompletionStreamResponse(
id=content["meta_info"]["id"],
created=created,
choices=[choice_data],
model=request.model,
)
yield f"data: {chunk.model_dump_json()}\n\n"
stream_buffers[index] = new_stream_buffer
is_firsts[index] = is_first
n_prev_tokens[index] = n_prev_token
else:
# No tool calls => just treat this as normal text
if delta or not (
request.stream_options
and request.stream_options.include_usage
):
choice_data = ChatCompletionResponseStreamChoice(
index=index,
delta=DeltaMessage(content=delta if delta else None),
finish_reason=(
None
if request.stream_options
and request.stream_options.include_usage
else finish_reason_type
),
matched_stop=(
finish_reason["matched"]
if finish_reason and "matched" in finish_reason
else None
),
logprobs=choice_logprobs,
)
chunk = ChatCompletionStreamResponse(
id=content["meta_info"]["id"],
created=created,
choices=[choice_data],
model=request.model,
)
yield f"data: {chunk.model_dump_json()}\n\n"
stream_buffers[index] = new_stream_buffer
is_firsts[index] = is_first
n_prev_tokens[index] = n_prev_token
if finish_reason_type == "stop" and request.tool_choice != "none":
parser = FunctionCallParser(
tools=request.tools,
tool_call_parser=tokenizer_manager.server_args.tool_call_parser,
)
if parser.has_tool_call(new_stream_buffer):
# if the stream ends with empty string after tool calls
finish_reason_type = "tool_calls"
if request.stream_options and request.stream_options.include_usage:
total_prompt_tokens = sum(
tokens
for i, tokens in prompt_tokens.items()
if i % request.n == 0
)
total_completion_tokens = sum(
tokens for tokens in completion_tokens.values()
)
cache_report = tokenizer_manager.server_args.enable_cache_report
if cache_report:
cached_tokens_sum = sum(
tokens for tokens in cached_tokens.values()
)
prompt_tokens_details = {"cached_tokens": cached_tokens_sum}
else:
prompt_tokens_details = None
usage = UsageInfo(
prompt_tokens=total_prompt_tokens,
completion_tokens=total_completion_tokens,
total_tokens=total_prompt_tokens + total_completion_tokens,
prompt_tokens_details=prompt_tokens_details,
)
else:
usage = None
if request.return_hidden_states and hidden_states:
for index, choice_hidden_states in hidden_states.items():
last_token_hidden_states = (
choice_hidden_states[-1]
if choice_hidden_states and len(choice_hidden_states) > 1
else []
)
hidden_states_chunk = ChatCompletionStreamResponse(
id=content["meta_info"]["id"],
created=created,
choices=[
ChatCompletionResponseStreamChoice(
index=index,
delta=DeltaMessage(
hidden_states=last_token_hidden_states
),
finish_reason=finish_reason_type,
)
],
model=request.model,
)
yield f"data: {hidden_states_chunk.model_dump_json()}\n\n"
final_usage_chunk = ChatCompletionStreamResponse(
id=content["meta_info"]["id"],
created=created,
choices=[
ChatCompletionResponseStreamChoice(
index=index,
delta=DeltaMessage(),
finish_reason=finish_reason_type,
)
],
model=request.model,
usage=usage,
)
yield f"data: {final_usage_chunk.model_dump_json()}\n\n"
except ValueError as e:
error = create_streaming_error_response(str(e))
yield f"data: {error}\n\n"
yield "data: [DONE]\n\n"
return StreamingResponse(
generate_stream_resp(),
media_type="text/event-stream",
background=tokenizer_manager.create_abort_task(adapted_request),
)
# Non-streaming response.
try:
ret = await tokenizer_manager.generate_request(
adapted_request, raw_request
).__anext__()
except ValueError as e:
return create_error_response(str(e))
if not isinstance(ret, list):
ret = [ret]
response = v1_chat_generate_response(
request,
ret,
created,
cache_report=tokenizer_manager.server_args.enable_cache_report,
tool_call_parser=tokenizer_manager.server_args.tool_call_parser,
reasoning_parser=tokenizer_manager.server_args.reasoning_parser,
)
return response
def v1_embedding_request(all_requests, tokenizer_manager):
prompts = []
sampling_params_list = []
first_prompt_type = type(all_requests[0].input)
for request in all_requests:
prompt = request.input
# Check for empty/whitespace string
prompt = _validate_prompt(request.input)
assert (
type(prompt) is first_prompt_type
), "All prompts must be of the same type in file input settings"
prompts.append(prompt)
if len(all_requests) == 1:
prompt = prompts[0]
if isinstance(prompt, str) or isinstance(prompt[0], str):
prompt_kwargs = {"text": prompt}
elif isinstance(prompt, list) and isinstance(
prompt[0], MultimodalEmbeddingInput
):
texts = []
images = []
for item in prompt:
# TODO simply use padding for text, we should use a better way to handle this
texts.append(item.text if item.text is not None else "padding")
images.append(item.image if item.image is not None else None)
generate_prompts = []
if chat_template_name is not None:
convs = generate_embedding_convs(texts, images, chat_template_name)
for conv in convs:
generate_prompts.append(conv.get_prompt())
else:
generate_prompts = texts
if len(generate_prompts) == 1:
prompt_kwargs = {"text": generate_prompts[0], "image_data": images[0]}
else:
prompt_kwargs = {"text": generate_prompts, "image_data": images}
else:
prompt_kwargs = {"input_ids": prompt}
request_ids = all_requests[0].rid
else:
if isinstance(prompts[0], str) or isinstance(prompts[0][0], str):
prompt_kwargs = {"text": prompts}
elif isinstance(prompts[0], list) and isinstance(
prompts[0][0], MultimodalEmbeddingInput
):
# TODO: multiple requests
raise NotImplementedError(
"Multiple requests with multimodal inputs are not supported yet"
)
else:
prompt_kwargs = {"input_ids": prompts}
request_ids = [req.rid for req in all_requests]
adapted_request = EmbeddingReqInput(
rid=request_ids,
**prompt_kwargs,
)
if len(all_requests) == 1:
return adapted_request, all_requests[0]
return adapted_request, all_requests
def v1_embedding_response(ret, model_path, to_file=False):
embedding_objects = []
prompt_tokens = 0
for idx, ret_item in enumerate(ret):
embedding_objects.append(
EmbeddingObject(
embedding=ret[idx]["embedding"],
index=idx,
)
)
prompt_tokens += ret[idx]["meta_info"]["prompt_tokens"]
return EmbeddingResponse(
data=embedding_objects,
model=model_path,
usage=UsageInfo(
prompt_tokens=prompt_tokens,
total_tokens=prompt_tokens,
),
)
async def v1_embeddings(tokenizer_manager, raw_request: Request):
try:
request_json = await raw_request.json()
except Exception as e:
return create_error_response("Invalid request body, error: ", str(e))
all_requests = [EmbeddingRequest(**request_json)]
adapted_request, request = v1_embedding_request(all_requests, tokenizer_manager)
try:
ret = await tokenizer_manager.generate_request(
adapted_request, raw_request
).__anext__()
except ValueError as e:
return create_error_response(str(e))
if not isinstance(ret, list):
ret = [ret]
response = v1_embedding_response(ret, tokenizer_manager.model_path)
return response
def v1_rerank_request(obj: V1RerankReqInput):
if obj.query is None:
raise ValueError("query is required")
if obj.documents is None or len(obj.documents) == 0:
raise ValueError("documents is required")
pairs = []
for doc in obj.documents:
pairs.append([obj.query, doc])
adapted_request = EmbeddingReqInput(
text=pairs,
is_cross_encoder_request=True,
)
return adapted_request
def v1_rerank_response(ret, obj: V1RerankReqInput):
response = []
for idx, ret_item in enumerate(ret):
response.append(
RerankResponse(
score=ret[idx]["embedding"],
document=obj.documents[idx],
index=idx,
meta_info=ret[idx]["meta_info"],
)
)
response.sort(key=lambda x: x.score, reverse=True)
return response
async def v1_rerank(tokenizer_manager, obj: V1RerankReqInput, raw_request: Request):
adapted_request = v1_rerank_request(obj)
try:
ret = await tokenizer_manager.generate_request(
adapted_request, raw_request
).__anext__()
except ValueError as e:
return create_error_response(str(e))
if not isinstance(ret, list):
ret = [ret]
response = v1_rerank_response(
ret,
obj,
)
return response
def to_openai_style_logprobs(
input_token_logprobs=None,
output_token_logprobs=None,
input_top_logprobs=None,
output_top_logprobs=None,
):
ret_logprobs = LogProbs()
def append_token_logprobs(token_logprobs):
for logprob, _, token_text in token_logprobs:
ret_logprobs.tokens.append(token_text)
ret_logprobs.token_logprobs.append(logprob)
# Not supported yet
ret_logprobs.text_offset.append(-1)
def append_top_logprobs(top_logprobs):
for tokens in top_logprobs:
if tokens is not None:
ret_logprobs.top_logprobs.append(
{token[2]: token[0] for token in tokens}
)
else:
ret_logprobs.top_logprobs.append(None)
if input_token_logprobs is not None:
append_token_logprobs(input_token_logprobs)
if output_token_logprobs is not None:
append_token_logprobs(output_token_logprobs)
if input_top_logprobs is not None:
append_top_logprobs(input_top_logprobs)
if output_top_logprobs is not None:
append_top_logprobs(output_top_logprobs)
return ret_logprobs
async def v1_score(tokenizer_manager, raw_request):
try:
# Parse request
request_data = await raw_request.json()
request = ScoringRequest(**request_data)
# Use tokenizer_manager's score_request method directly
scores = await tokenizer_manager.score_request(
query=request.query,
items=request.items,
label_token_ids=request.label_token_ids,
apply_softmax=request.apply_softmax,
item_first=request.item_first,
request=request,
)
# Create response with just the scores, without usage info
response = ScoringResponse(
scores=scores,
model=request.model,
)
return response
except Exception as e:
logger.error(f"Error in v1_score: {str(e)}")
return create_error_response(str(e))
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Pydantic models for OpenAI API protocol"""
import time
from typing import Dict, List, Optional, Union
from pydantic import BaseModel, Field, model_serializer, root_validator
from typing_extensions import Literal
class ModelCard(BaseModel):
"""Model cards."""
id: str
object: str = "model"
created: int = Field(default_factory=lambda: int(time.time()))
owned_by: str = "sglang"
root: Optional[str] = None
max_model_len: Optional[int] = None
class ModelList(BaseModel):
"""Model list consists of model cards."""
object: str = "list"
data: List[ModelCard] = Field(default_factory=list)
class ErrorResponse(BaseModel):
object: str = "error"
message: str
type: str
param: Optional[str] = None
code: int
class LogProbs(BaseModel):
text_offset: List[int] = Field(default_factory=list)
token_logprobs: List[Optional[float]] = Field(default_factory=list)
tokens: List[str] = Field(default_factory=list)
top_logprobs: List[Optional[Dict[str, float]]] = Field(default_factory=list)
class TopLogprob(BaseModel):
token: str
bytes: List[int]
logprob: float
class ChatCompletionTokenLogprob(BaseModel):
token: str
bytes: List[int]
logprob: float
top_logprobs: List[TopLogprob]
class ChoiceLogprobs(BaseModel):
# build for v1/chat/completions response
content: List[ChatCompletionTokenLogprob]
class UsageInfo(BaseModel):
prompt_tokens: int = 0
total_tokens: int = 0
completion_tokens: Optional[int] = 0
# only used to return cached tokens when --enable-cache-report is set
prompt_tokens_details: Optional[Dict[str, int]] = None
class StreamOptions(BaseModel):
include_usage: Optional[bool] = False
class JsonSchemaResponseFormat(BaseModel):
name: str
description: Optional[str] = None
# use alias to workaround pydantic conflict
schema_: Optional[Dict[str, object]] = Field(alias="schema", default=None)
strict: Optional[bool] = False
class FileRequest(BaseModel):
# https://platform.openai.com/docs/api-reference/files/create
file: bytes # The File object (not file name) to be uploaded
purpose: str = (
"batch" # The intended purpose of the uploaded file, default is "batch"
)
class FileResponse(BaseModel):
id: str
object: str = "file"
bytes: int
created_at: int
filename: str
purpose: str
class FileDeleteResponse(BaseModel):
id: str
object: str = "file"
deleted: bool
class BatchRequest(BaseModel):
input_file_id: (
str # The ID of an uploaded file that contains requests for the new batch
)
endpoint: str # The endpoint to be used for all requests in the batch
completion_window: str # The time frame within which the batch should be processed
metadata: Optional[dict] = None # Optional custom metadata for the batch
class BatchResponse(BaseModel):
id: str
object: str = "batch"
endpoint: str
errors: Optional[dict] = None
input_file_id: str
completion_window: str
status: str = "validating"
output_file_id: Optional[str] = None
error_file_id: Optional[str] = None
created_at: int
in_progress_at: Optional[int] = None
expires_at: Optional[int] = None
finalizing_at: Optional[int] = None
completed_at: Optional[int] = None
failed_at: Optional[int] = None
expired_at: Optional[int] = None
cancelling_at: Optional[int] = None
cancelled_at: Optional[int] = None
request_counts: Optional[dict] = None
metadata: Optional[dict] = None
class CompletionRequest(BaseModel):
# Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/completions/create
model: str
prompt: Union[List[int], List[List[int]], str, List[str]]
best_of: Optional[int] = None
echo: bool = False
frequency_penalty: float = 0.0
logit_bias: Optional[Dict[str, float]] = None
logprobs: Optional[int] = None
max_tokens: int = 16
n: int = 1
presence_penalty: float = 0.0
seed: Optional[int] = None
stop: Optional[Union[str, List[str]]] = None
stream: bool = False
stream_options: Optional[StreamOptions] = None
suffix: Optional[str] = None
temperature: float = 1.0
top_p: float = 1.0
user: Optional[str] = None
# Extra parameters for SRT backend only and will be ignored by OpenAI models.
top_k: int = -1
min_p: float = 0.0
min_tokens: int = 0
json_schema: Optional[str] = None
regex: Optional[str] = None
ebnf: Optional[str] = None
repetition_penalty: float = 1.0
stop_token_ids: Optional[List[int]] = None
no_stop_trim: bool = False
ignore_eos: bool = False
skip_special_tokens: bool = True
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
session_params: Optional[Dict] = None
return_hidden_states: Optional[bool] = False
# For PD disaggregation
bootstrap_host: Optional[str] = None
bootstrap_port: Optional[int] = None
bootstrap_room: Optional[int] = None
class CompletionResponseChoice(BaseModel):
index: int
text: str
logprobs: Optional[LogProbs] = None
finish_reason: Literal["stop", "length", "content_filter", "abort"]
matched_stop: Union[None, int, str] = None
hidden_states: Optional[object] = None
@model_serializer
def _serialize(self):
return exclude_if_none(self, ["hidden_states"])
class CompletionResponse(BaseModel):
id: str
object: str = "text_completion"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[CompletionResponseChoice]
usage: UsageInfo
class CompletionResponseStreamChoice(BaseModel):
index: int
text: str
logprobs: Optional[LogProbs] = None
finish_reason: Optional[Literal["stop", "length", "content_filter"]] = None
matched_stop: Union[None, int, str] = None
hidden_states: Optional[object] = None
@model_serializer
def _serialize(self):
return exclude_if_none(self, ["hidden_states"])
class CompletionStreamResponse(BaseModel):
id: str
object: str = "text_completion"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[CompletionResponseStreamChoice]
usage: Optional[UsageInfo] = None
class ChatCompletionMessageContentTextPart(BaseModel):
type: Literal["text"]
text: str
class ChatCompletionMessageContentImageURL(BaseModel):
url: str
detail: Optional[Literal["auto", "low", "high"]] = "auto"
class ChatCompletionMessageContentAudioURL(BaseModel):
url: str
class ChatCompletionMessageContentImagePart(BaseModel):
type: Literal["image_url"]
image_url: ChatCompletionMessageContentImageURL
modalities: Optional[Literal["image", "multi-images", "video"]] = "image"
class ChatCompletionMessageContentAudioPart(BaseModel):
type: Literal["audio_url"]
audio_url: ChatCompletionMessageContentAudioURL
ChatCompletionMessageContentPart = Union[
ChatCompletionMessageContentTextPart,
ChatCompletionMessageContentImagePart,
ChatCompletionMessageContentAudioPart,
]
class FunctionResponse(BaseModel):
"""Function response."""
name: Optional[str] = None
arguments: Optional[str] = None
class ToolCall(BaseModel):
"""Tool call response."""
id: Optional[str] = None
index: Optional[int] = None
type: Literal["function"] = "function"
function: FunctionResponse
class ChatCompletionMessageGenericParam(BaseModel):
role: Literal["system", "assistant", "tool"]
content: Union[str, List[ChatCompletionMessageContentTextPart], None]
tool_call_id: Optional[str] = None
name: Optional[str] = None
reasoning_content: Optional[str] = None
tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])
class ChatCompletionMessageUserParam(BaseModel):
role: Literal["user"]
content: Union[str, List[ChatCompletionMessageContentPart]]
ChatCompletionMessageParam = Union[
ChatCompletionMessageGenericParam, ChatCompletionMessageUserParam
]
class ResponseFormat(BaseModel):
type: Literal["text", "json_object", "json_schema"]
json_schema: Optional[JsonSchemaResponseFormat] = None
class StructuresResponseFormat(BaseModel):
begin: str
schema_: Optional[Dict[str, object]] = Field(alias="schema", default=None)
end: str
class StructuralTagResponseFormat(BaseModel):
type: Literal["structural_tag"]
structures: List[StructuresResponseFormat]
triggers: List[str]
class Function(BaseModel):
"""Function descriptions."""
description: Optional[str] = Field(default=None, examples=[None])
name: Optional[str] = None
parameters: Optional[object] = None
strict: bool = False
class Tool(BaseModel):
"""Function wrapper."""
type: str = Field(default="function", examples=["function"])
function: Function
class ToolChoiceFuncName(BaseModel):
"""The name of tool choice function."""
name: Optional[str] = None
class ToolChoice(BaseModel):
"""The tool choice definition."""
function: ToolChoiceFuncName
type: Literal["function"] = Field(default="function", examples=["function"])
class ChatCompletionRequest(BaseModel):
# Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/chat/create
messages: List[ChatCompletionMessageParam]
model: str
frequency_penalty: float = 0.0
logit_bias: Optional[Dict[str, float]] = None
logprobs: bool = False
top_logprobs: Optional[int] = None
max_tokens: Optional[int] = Field(
default=None,
deprecated="max_tokens is deprecated in favor of the max_completion_tokens field",
description="The maximum number of tokens that can be generated in the chat completion. ",
)
max_completion_tokens: Optional[int] = Field(
default=None,
description="The maximum number of completion tokens for a chat completion request, "
"including visible output tokens and reasoning tokens. Input tokens are not included. ",
)
n: int = 1
presence_penalty: float = 0.0
response_format: Optional[Union[ResponseFormat, StructuralTagResponseFormat]] = None
seed: Optional[int] = None
stop: Optional[Union[str, List[str]]] = None
stream: bool = False
stream_options: Optional[StreamOptions] = None
temperature: float = 0.7
top_p: float = 1.0
user: Optional[str] = None
tools: Optional[List[Tool]] = Field(default=None, examples=[None])
tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] = Field(
default="auto", examples=["none"]
) # noqa
@root_validator(pre=True)
def set_tool_choice_default(cls, values):
if values.get("tool_choice") is None:
if values.get("tools") is None:
values["tool_choice"] = "none"
else:
values["tool_choice"] = "auto"
return values
# Extra parameters for SRT backend only and will be ignored by OpenAI models.
top_k: int = -1
min_p: float = 0.0
min_tokens: int = 0
regex: Optional[str] = None
ebnf: Optional[str] = None
repetition_penalty: float = 1.0
stop_token_ids: Optional[List[int]] = None
no_stop_trim: bool = False
ignore_eos: bool = False
continue_final_message: bool = False
skip_special_tokens: bool = True
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
session_params: Optional[Dict] = None
separate_reasoning: bool = True
stream_reasoning: bool = True
chat_template_kwargs: Optional[Dict] = None
# The request id.
rid: Optional[str] = None
# For PD disaggregation
bootstrap_host: Optional[str] = None
bootstrap_port: Optional[int] = None
bootstrap_room: Optional[int] = None
# Hidden States
return_hidden_states: Optional[bool] = False
class ChatMessage(BaseModel):
role: Optional[str] = None
content: Optional[str] = None
reasoning_content: Optional[str] = None
tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])
class ChatCompletionResponseChoice(BaseModel):
index: int
message: ChatMessage
logprobs: Optional[Union[LogProbs, ChoiceLogprobs]] = None
finish_reason: Literal[
"stop", "length", "tool_calls", "content_filter", "function_call", "abort"
]
matched_stop: Union[None, int, str] = None
hidden_states: Optional[object] = None
@model_serializer
def _serialize(self):
return exclude_if_none(self, ["hidden_states"])
class ChatCompletionResponse(BaseModel):
id: str
object: str = "chat.completion"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[ChatCompletionResponseChoice]
usage: UsageInfo
class DeltaMessage(BaseModel):
role: Optional[str] = None
content: Optional[str] = None
reasoning_content: Optional[str] = None
tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])
hidden_states: Optional[object] = None
@model_serializer
def _serialize(self):
return exclude_if_none(self, ["hidden_states"])
class ChatCompletionResponseStreamChoice(BaseModel):
index: int
delta: DeltaMessage
logprobs: Optional[Union[LogProbs, ChoiceLogprobs]] = None
finish_reason: Optional[
Literal["stop", "length", "tool_calls", "content_filter", "function_call"]
] = None
matched_stop: Union[None, int, str] = None
class ChatCompletionStreamResponse(BaseModel):
id: str
object: str = "chat.completion.chunk"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[ChatCompletionResponseStreamChoice]
usage: Optional[UsageInfo] = None
class MultimodalEmbeddingInput(BaseModel):
text: Optional[str] = None
image: Optional[str] = None
class EmbeddingRequest(BaseModel):
# Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/embeddings/create
input: Union[
List[int], List[List[int]], str, List[str], List[MultimodalEmbeddingInput]
]
model: str
encoding_format: str = "float"
dimensions: int = None
user: Optional[str] = None
# The request id.
rid: Optional[str] = None
class EmbeddingObject(BaseModel):
embedding: List[float]
index: int
object: str = "embedding"
class EmbeddingResponse(BaseModel):
data: List[EmbeddingObject]
model: str
object: str = "list"
usage: Optional[UsageInfo] = None
class ScoringRequest(BaseModel):
query: Optional[Union[str, List[int]]] = (
None # Query text or pre-tokenized token IDs
)
items: Optional[Union[str, List[str], List[List[int]]]] = (
None # Item text(s) or pre-tokenized token IDs
)
label_token_ids: Optional[List[int]] = (
None # Token IDs to compute probabilities for
)
apply_softmax: bool = False
item_first: bool = False
model: str
class ScoringResponse(BaseModel):
scores: List[
List[float]
] # List of lists of probabilities, each in the order of label_token_ids
model: str
usage: Optional[UsageInfo] = None
object: str = "scoring"
class RerankResponse(BaseModel):
score: float
document: str
index: int
meta_info: Optional[dict] = None
def exclude_if_none(obj, field_names: List[str]):
omit_if_none_fields = {k for k, v in obj.model_fields.items() if k in field_names}
return {k: v for k, v in obj if k not in omit_if_none_fields or v is not None}
from typing import Dict, Tuple
from typing import Dict, Optional, Tuple, Type
class StreamingParseResult:
......@@ -32,17 +32,26 @@ class BaseReasoningFormatDetector:
One-time parsing: Detects and parses reasoning sections in the provided text.
Returns both reasoning content and normal text separately.
"""
text = text.replace(self.think_start_token, "").strip()
if self.think_end_token not in text:
in_reasoning = self._in_reasoning or text.startswith(self.think_start_token)
if not in_reasoning:
return StreamingParseResult(normal_text=text)
# The text is considered to be in a reasoning block.
processed_text = text.replace(self.think_start_token, "").strip()
if self.think_end_token not in processed_text:
# Assume reasoning was truncated before `</think>` token
return StreamingParseResult(reasoning_text=text)
return StreamingParseResult(reasoning_text=processed_text)
# Extract reasoning content
splits = text.split(self.think_end_token, maxsplit=1)
splits = processed_text.split(self.think_end_token, maxsplit=1)
reasoning_text = splits[0]
text = splits[1].strip()
normal_text = splits[1].strip()
return StreamingParseResult(normal_text=text, reasoning_text=reasoning_text)
return StreamingParseResult(
normal_text=normal_text, reasoning_text=reasoning_text
)
def parse_streaming_increment(self, new_text: str) -> StreamingParseResult:
"""
......@@ -61,6 +70,7 @@ class BaseReasoningFormatDetector:
if not self.stripped_think_start and self.think_start_token in current_text:
current_text = current_text.replace(self.think_start_token, "")
self.stripped_think_start = True
self._in_reasoning = True
# Handle end of reasoning block
if self._in_reasoning and self.think_end_token in current_text:
......@@ -131,11 +141,11 @@ class Qwen3Detector(BaseReasoningFormatDetector):
"""
def __init__(self, stream_reasoning: bool = True):
# Qwen3 is assumed to be reasoning until `</think>` token
# Qwen3 won't be in reasoning mode when user passes `enable_thinking=False`
super().__init__(
"<think>",
"</think>",
force_reasoning=True,
force_reasoning=False,
stream_reasoning=stream_reasoning,
)
......@@ -151,12 +161,12 @@ class ReasoningParser:
If True, streams reasoning content as it arrives.
"""
DetectorMap: Dict[str, BaseReasoningFormatDetector] = {
DetectorMap: Dict[str, Type[BaseReasoningFormatDetector]] = {
"deepseek-r1": DeepSeekR1Detector,
"qwen3": Qwen3Detector,
}
def __init__(self, model_type: str = None, stream_reasoning: bool = True):
def __init__(self, model_type: Optional[str] = None, stream_reasoning: bool = True):
if not model_type:
raise ValueError("Model type must be specified")
......
# sglang/test/srt/openai/conftest.py
import os
import socket
import subprocess
import sys
import tempfile
import time
from contextlib import closing
from typing import Generator
import pytest
import requests
from sglang.srt.utils import kill_process_tree # reuse SGLang helper
from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST
SERVER_MODULE = "sglang.srt.entrypoints.openai.api_server"
DEFAULT_MODEL = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
STARTUP_TIMEOUT = float(os.getenv("SGLANG_OPENAI_STARTUP_TIMEOUT", 120))
def _pick_free_port() -> int:
with closing(socket.socket()) as s:
s.bind(("127.0.0.1", 0))
return s.getsockname()[1]
def _wait_until_healthy(proc: subprocess.Popen, base: str, timeout: float) -> None:
start = time.perf_counter()
while time.perf_counter() - start < timeout:
if proc.poll() is not None: # crashed
raise RuntimeError("api_server terminated prematurely")
try:
if requests.get(f"{base}/health", timeout=1).status_code == 200:
return
except requests.RequestException:
pass
time.sleep(0.4)
raise RuntimeError("api_server readiness probe timed out")
def launch_openai_server(model: str = DEFAULT_MODEL, **kw):
"""Spawn the draft OpenAI-compatible server and wait until it's ready."""
port = _pick_free_port()
cmd = [
sys.executable,
"-m",
SERVER_MODULE,
"--model-path",
model,
"--host",
"127.0.0.1",
"--port",
str(port),
*map(str, kw.get("args", [])),
]
env = {**os.environ, **kw.get("env", {})}
# Write logs to a temp file so the child never blocks on a full pipe.
log_file = tempfile.NamedTemporaryFile("w+", delete=False)
proc = subprocess.Popen(
cmd,
env=env,
stdout=log_file,
stderr=subprocess.STDOUT,
text=True,
)
base = f"http://127.0.0.1:{port}"
try:
_wait_until_healthy(proc, base, STARTUP_TIMEOUT)
except Exception as e:
proc.terminate()
proc.wait(5)
log_file.seek(0)
print("\n--- api_server log ---\n", log_file.read(), file=sys.stderr)
raise e
return proc, base, log_file
@pytest.fixture(scope="session")
def openai_server() -> Generator[str, None, None]:
"""PyTest fixture that provides the server's base URL and cleans up."""
proc, base, log_file = launch_openai_server()
yield base
kill_process_tree(proc.pid)
log_file.close()
......@@ -67,29 +67,6 @@ from sglang.srt.entrypoints.openai.protocol import (
class TestModelCard(unittest.TestCase):
"""Test ModelCard protocol model"""
def test_basic_model_card_creation(self):
"""Test basic model card creation with required fields"""
card = ModelCard(id="test-model")
self.assertEqual(card.id, "test-model")
self.assertEqual(card.object, "model")
self.assertEqual(card.owned_by, "sglang")
self.assertIsInstance(card.created, int)
self.assertIsNone(card.root)
self.assertIsNone(card.max_model_len)
def test_model_card_with_optional_fields(self):
"""Test model card with optional fields"""
card = ModelCard(
id="test-model",
root="/path/to/model",
max_model_len=2048,
created=1234567890,
)
self.assertEqual(card.id, "test-model")
self.assertEqual(card.root, "/path/to/model")
self.assertEqual(card.max_model_len, 2048)
self.assertEqual(card.created, 1234567890)
def test_model_card_serialization(self):
"""Test model card JSON serialization"""
card = ModelCard(id="test-model", max_model_len=4096)
......@@ -120,53 +97,6 @@ class TestModelList(unittest.TestCase):
self.assertEqual(model_list.data[1].id, "model-2")
class TestErrorResponse(unittest.TestCase):
"""Test ErrorResponse protocol model"""
def test_basic_error_response(self):
"""Test basic error response creation"""
error = ErrorResponse(
message="Invalid request", type="BadRequestError", code=400
)
self.assertEqual(error.object, "error")
self.assertEqual(error.message, "Invalid request")
self.assertEqual(error.type, "BadRequestError")
self.assertEqual(error.code, 400)
self.assertIsNone(error.param)
def test_error_response_with_param(self):
"""Test error response with parameter"""
error = ErrorResponse(
message="Invalid temperature",
type="ValidationError",
code=422,
param="temperature",
)
self.assertEqual(error.param, "temperature")
class TestUsageInfo(unittest.TestCase):
"""Test UsageInfo protocol model"""
def test_basic_usage_info(self):
"""Test basic usage info creation"""
usage = UsageInfo(prompt_tokens=10, completion_tokens=20, total_tokens=30)
self.assertEqual(usage.prompt_tokens, 10)
self.assertEqual(usage.completion_tokens, 20)
self.assertEqual(usage.total_tokens, 30)
self.assertIsNone(usage.prompt_tokens_details)
def test_usage_info_with_cache_details(self):
"""Test usage info with cache details"""
usage = UsageInfo(
prompt_tokens=10,
completion_tokens=20,
total_tokens=30,
prompt_tokens_details={"cached_tokens": 5},
)
self.assertEqual(usage.prompt_tokens_details, {"cached_tokens": 5})
class TestCompletionRequest(unittest.TestCase):
"""Test CompletionRequest protocol model"""
......@@ -181,30 +111,6 @@ class TestCompletionRequest(unittest.TestCase):
self.assertFalse(request.stream) # default
self.assertFalse(request.echo) # default
def test_completion_request_with_options(self):
"""Test completion request with various options"""
request = CompletionRequest(
model="test-model",
prompt=["Hello", "world"],
max_tokens=100,
temperature=0.7,
top_p=0.9,
n=2,
stream=True,
echo=True,
stop=[".", "!"],
logprobs=5,
)
self.assertEqual(request.prompt, ["Hello", "world"])
self.assertEqual(request.max_tokens, 100)
self.assertEqual(request.temperature, 0.7)
self.assertEqual(request.top_p, 0.9)
self.assertEqual(request.n, 2)
self.assertTrue(request.stream)
self.assertTrue(request.echo)
self.assertEqual(request.stop, [".", "!"])
self.assertEqual(request.logprobs, 5)
def test_completion_request_sglang_extensions(self):
"""Test completion request with SGLang-specific extensions"""
request = CompletionRequest(
......@@ -233,26 +139,6 @@ class TestCompletionRequest(unittest.TestCase):
CompletionRequest(model="test-model") # missing prompt
class TestCompletionResponse(unittest.TestCase):
"""Test CompletionResponse protocol model"""
def test_basic_completion_response(self):
"""Test basic completion response"""
choice = CompletionResponseChoice(
index=0, text="Hello world!", finish_reason="stop"
)
usage = UsageInfo(prompt_tokens=2, completion_tokens=3, total_tokens=5)
response = CompletionResponse(
id="test-id", model="test-model", choices=[choice], usage=usage
)
self.assertEqual(response.id, "test-id")
self.assertEqual(response.object, "text_completion")
self.assertEqual(response.model, "test-model")
self.assertEqual(len(response.choices), 1)
self.assertEqual(response.choices[0].text, "Hello world!")
self.assertEqual(response.usage.total_tokens, 5)
class TestChatCompletionRequest(unittest.TestCase):
"""Test ChatCompletionRequest protocol model"""
......@@ -268,48 +154,6 @@ class TestChatCompletionRequest(unittest.TestCase):
self.assertFalse(request.stream) # default
self.assertEqual(request.tool_choice, "none") # default when no tools
def test_chat_completion_with_multimodal_content(self):
"""Test chat completion with multimodal content"""
messages = [
{
"role": "user",
"content": [
{"type": "text", "text": "What's in this image?"},
{
"type": "image_url",
"image_url": {"url": "data:image/jpeg;base64,/9j/4AAQ..."},
},
],
}
]
request = ChatCompletionRequest(model="test-model", messages=messages)
self.assertEqual(len(request.messages[0].content), 2)
self.assertEqual(request.messages[0].content[0].type, "text")
self.assertEqual(request.messages[0].content[1].type, "image_url")
def test_chat_completion_with_tools(self):
"""Test chat completion with tools"""
messages = [{"role": "user", "content": "What's the weather?"}]
tools = [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get weather information",
"parameters": {
"type": "object",
"properties": {"location": {"type": "string"}},
},
},
}
]
request = ChatCompletionRequest(
model="test-model", messages=messages, tools=tools
)
self.assertEqual(len(request.tools), 1)
self.assertEqual(request.tools[0].function.name, "get_weather")
self.assertEqual(request.tool_choice, "auto") # default when tools present
def test_chat_completion_tool_choice_validation(self):
"""Test tool choice validation logic"""
messages = [{"role": "user", "content": "Hello"}]
......@@ -349,289 +193,6 @@ class TestChatCompletionRequest(unittest.TestCase):
self.assertEqual(request.chat_template_kwargs, {"custom_param": "value"})
class TestChatCompletionResponse(unittest.TestCase):
"""Test ChatCompletionResponse protocol model"""
def test_basic_chat_completion_response(self):
"""Test basic chat completion response"""
message = ChatMessage(role="assistant", content="Hello there!")
choice = ChatCompletionResponseChoice(
index=0, message=message, finish_reason="stop"
)
usage = UsageInfo(prompt_tokens=2, completion_tokens=3, total_tokens=5)
response = ChatCompletionResponse(
id="test-id", model="test-model", choices=[choice], usage=usage
)
self.assertEqual(response.id, "test-id")
self.assertEqual(response.object, "chat.completion")
self.assertEqual(response.model, "test-model")
self.assertEqual(len(response.choices), 1)
self.assertEqual(response.choices[0].message.content, "Hello there!")
def test_chat_completion_response_with_tool_calls(self):
"""Test chat completion response with tool calls"""
tool_call = ToolCall(
id="call_123",
function=FunctionResponse(
name="get_weather", arguments='{"location": "San Francisco"}'
),
)
message = ChatMessage(role="assistant", content=None, tool_calls=[tool_call])
choice = ChatCompletionResponseChoice(
index=0, message=message, finish_reason="tool_calls"
)
usage = UsageInfo(prompt_tokens=10, completion_tokens=5, total_tokens=15)
response = ChatCompletionResponse(
id="test-id", model="test-model", choices=[choice], usage=usage
)
self.assertEqual(
response.choices[0].message.tool_calls[0].function.name, "get_weather"
)
self.assertEqual(response.choices[0].finish_reason, "tool_calls")
class TestEmbeddingRequest(unittest.TestCase):
"""Test EmbeddingRequest protocol model"""
def test_basic_embedding_request(self):
"""Test basic embedding request"""
request = EmbeddingRequest(model="test-model", input="Hello world")
self.assertEqual(request.model, "test-model")
self.assertEqual(request.input, "Hello world")
self.assertEqual(request.encoding_format, "float") # default
self.assertIsNone(request.dimensions) # default
def test_embedding_request_with_list_input(self):
"""Test embedding request with list input"""
request = EmbeddingRequest(
model="test-model", input=["Hello", "world"], dimensions=512
)
self.assertEqual(request.input, ["Hello", "world"])
self.assertEqual(request.dimensions, 512)
def test_multimodal_embedding_request(self):
"""Test multimodal embedding request"""
multimodal_input = [
MultimodalEmbeddingInput(text="Hello", image="base64_image_data"),
MultimodalEmbeddingInput(text="World", image=None),
]
request = EmbeddingRequest(model="test-model", input=multimodal_input)
self.assertEqual(len(request.input), 2)
self.assertEqual(request.input[0].text, "Hello")
self.assertEqual(request.input[0].image, "base64_image_data")
self.assertEqual(request.input[1].text, "World")
self.assertIsNone(request.input[1].image)
class TestEmbeddingResponse(unittest.TestCase):
"""Test EmbeddingResponse protocol model"""
def test_basic_embedding_response(self):
"""Test basic embedding response"""
embedding_obj = EmbeddingObject(embedding=[0.1, 0.2, 0.3], index=0)
usage = UsageInfo(prompt_tokens=3, total_tokens=3)
response = EmbeddingResponse(
data=[embedding_obj], model="test-model", usage=usage
)
self.assertEqual(response.object, "list")
self.assertEqual(len(response.data), 1)
self.assertEqual(response.data[0].embedding, [0.1, 0.2, 0.3])
self.assertEqual(response.data[0].index, 0)
self.assertEqual(response.usage.prompt_tokens, 3)
class TestScoringRequest(unittest.TestCase):
"""Test ScoringRequest protocol model"""
def test_basic_scoring_request(self):
"""Test basic scoring request"""
request = ScoringRequest(
model="test-model", query="Hello", items=["World", "Earth"]
)
self.assertEqual(request.model, "test-model")
self.assertEqual(request.query, "Hello")
self.assertEqual(request.items, ["World", "Earth"])
self.assertFalse(request.apply_softmax) # default
self.assertFalse(request.item_first) # default
def test_scoring_request_with_token_ids(self):
"""Test scoring request with token IDs"""
request = ScoringRequest(
model="test-model",
query=[1, 2, 3],
items=[[4, 5], [6, 7]],
label_token_ids=[8, 9],
apply_softmax=True,
item_first=True,
)
self.assertEqual(request.query, [1, 2, 3])
self.assertEqual(request.items, [[4, 5], [6, 7]])
self.assertEqual(request.label_token_ids, [8, 9])
self.assertTrue(request.apply_softmax)
self.assertTrue(request.item_first)
class TestScoringResponse(unittest.TestCase):
"""Test ScoringResponse protocol model"""
def test_basic_scoring_response(self):
"""Test basic scoring response"""
response = ScoringResponse(scores=[[0.1, 0.9], [0.3, 0.7]], model="test-model")
self.assertEqual(response.object, "scoring")
self.assertEqual(response.scores, [[0.1, 0.9], [0.3, 0.7]])
self.assertEqual(response.model, "test-model")
self.assertIsNone(response.usage) # default
class TestFileOperations(unittest.TestCase):
"""Test file operation protocol models"""
def test_file_request(self):
"""Test file request model"""
file_data = b"test file content"
request = FileRequest(file=file_data, purpose="batch")
self.assertEqual(request.file, file_data)
self.assertEqual(request.purpose, "batch")
def test_file_response(self):
"""Test file response model"""
response = FileResponse(
id="file-123",
bytes=1024,
created_at=1234567890,
filename="test.jsonl",
purpose="batch",
)
self.assertEqual(response.id, "file-123")
self.assertEqual(response.object, "file")
self.assertEqual(response.bytes, 1024)
self.assertEqual(response.filename, "test.jsonl")
def test_file_delete_response(self):
"""Test file delete response model"""
response = FileDeleteResponse(id="file-123", deleted=True)
self.assertEqual(response.id, "file-123")
self.assertEqual(response.object, "file")
self.assertTrue(response.deleted)
class TestBatchOperations(unittest.TestCase):
"""Test batch operation protocol models"""
def test_batch_request(self):
"""Test batch request model"""
request = BatchRequest(
input_file_id="file-123",
endpoint="/v1/chat/completions",
completion_window="24h",
metadata={"custom": "value"},
)
self.assertEqual(request.input_file_id, "file-123")
self.assertEqual(request.endpoint, "/v1/chat/completions")
self.assertEqual(request.completion_window, "24h")
self.assertEqual(request.metadata, {"custom": "value"})
def test_batch_response(self):
"""Test batch response model"""
response = BatchResponse(
id="batch-123",
endpoint="/v1/chat/completions",
input_file_id="file-123",
completion_window="24h",
created_at=1234567890,
)
self.assertEqual(response.id, "batch-123")
self.assertEqual(response.object, "batch")
self.assertEqual(response.status, "validating") # default
self.assertEqual(response.endpoint, "/v1/chat/completions")
class TestResponseFormats(unittest.TestCase):
"""Test response format protocol models"""
def test_basic_response_format(self):
"""Test basic response format"""
format_obj = ResponseFormat(type="json_object")
self.assertEqual(format_obj.type, "json_object")
self.assertIsNone(format_obj.json_schema)
def test_json_schema_response_format(self):
"""Test JSON schema response format"""
schema = {"type": "object", "properties": {"name": {"type": "string"}}}
json_schema = JsonSchemaResponseFormat(
name="person_schema", description="Person schema", schema=schema
)
format_obj = ResponseFormat(type="json_schema", json_schema=json_schema)
self.assertEqual(format_obj.type, "json_schema")
self.assertEqual(format_obj.json_schema.name, "person_schema")
self.assertEqual(format_obj.json_schema.schema_, schema)
def test_structural_tag_response_format(self):
"""Test structural tag response format"""
structures = [
{
"begin": "<thinking>",
"schema_": {"type": "string"},
"end": "</thinking>",
}
]
format_obj = StructuralTagResponseFormat(
type="structural_tag", structures=structures, triggers=["think"]
)
self.assertEqual(format_obj.type, "structural_tag")
self.assertEqual(len(format_obj.structures), 1)
self.assertEqual(format_obj.triggers, ["think"])
class TestLogProbs(unittest.TestCase):
"""Test LogProbs protocol models"""
def test_basic_logprobs(self):
"""Test basic LogProbs model"""
logprobs = LogProbs(
text_offset=[0, 5, 11],
token_logprobs=[-0.1, -0.2, -0.3],
tokens=["Hello", " ", "world"],
top_logprobs=[{"Hello": -0.1}, {" ": -0.2}, {"world": -0.3}],
)
self.assertEqual(len(logprobs.tokens), 3)
self.assertEqual(logprobs.tokens, ["Hello", " ", "world"])
self.assertEqual(logprobs.token_logprobs, [-0.1, -0.2, -0.3])
def test_choice_logprobs(self):
"""Test ChoiceLogprobs model"""
token_logprob = ChatCompletionTokenLogprob(
token="Hello",
bytes=[72, 101, 108, 108, 111],
logprob=-0.1,
top_logprobs=[
TopLogprob(token="Hello", bytes=[72, 101, 108, 108, 111], logprob=-0.1)
],
)
choice_logprobs = ChoiceLogprobs(content=[token_logprob])
self.assertEqual(len(choice_logprobs.content), 1)
self.assertEqual(choice_logprobs.content[0].token, "Hello")
class TestStreamingModels(unittest.TestCase):
"""Test streaming response models"""
def test_stream_options(self):
"""Test StreamOptions model"""
options = StreamOptions(include_usage=True)
self.assertTrue(options.include_usage)
def test_chat_completion_stream_response(self):
"""Test ChatCompletionStreamResponse model"""
delta = DeltaMessage(role="assistant", content="Hello")
choice = ChatCompletionResponseStreamChoice(index=0, delta=delta)
response = ChatCompletionStreamResponse(
id="test-id", model="test-model", choices=[choice]
)
self.assertEqual(response.object, "chat.completion.chunk")
self.assertEqual(response.choices[0].delta.content, "Hello")
class TestModelSerialization(unittest.TestCase):
"""Test model serialization with hidden states"""
......@@ -680,11 +241,6 @@ class TestModelSerialization(unittest.TestCase):
class TestValidationEdgeCases(unittest.TestCase):
"""Test edge cases and validation scenarios"""
def test_empty_messages_validation(self):
"""Test validation with empty messages"""
with self.assertRaises(ValidationError):
ChatCompletionRequest(model="test-model", messages=[])
def test_invalid_tool_choice_type(self):
"""Test invalid tool choice type"""
messages = [{"role": "user", "content": "Hello"}]
......@@ -698,13 +254,6 @@ class TestValidationEdgeCases(unittest.TestCase):
with self.assertRaises(ValidationError):
CompletionRequest(model="test-model", prompt="Hello", max_tokens=-1)
def test_invalid_temperature_range(self):
"""Test invalid temperature values"""
# Note: The current protocol doesn't enforce temperature range,
# but this test documents expected behavior
request = CompletionRequest(model="test-model", prompt="Hello", temperature=5.0)
self.assertEqual(request.temperature, 5.0) # Currently allowed
def test_model_serialization_roundtrip(self):
"""Test that models can be serialized and deserialized"""
original_request = ChatCompletionRequest(
......
# sglang/test/srt/openai/test_server.py
import requests
from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST as MODEL_ID
def test_health(openai_server: str):
r = requests.get(f"{openai_server}/health")
assert r.status_code == 200
# FastAPI returns an empty body → r.text == ""
assert r.text == ""
def test_models_endpoint(openai_server: str):
r = requests.get(f"{openai_server}/v1/models")
assert r.status_code == 200, r.text
payload = r.json()
# Basic contract
assert "data" in payload and isinstance(payload["data"], list) and payload["data"]
# Validate fields of the first model card
first = payload["data"][0]
for key in ("id", "root", "max_model_len"):
assert key in first, f"missing {key} in {first}"
# max_model_len must be positive
assert isinstance(first["max_model_len"], int) and first["max_model_len"] > 0
# The server should report the same model id we launched it with
ids = {m["id"] for m in payload["data"]}
assert MODEL_ID in ids
def test_get_model_info(openai_server: str):
r = requests.get(f"{openai_server}/get_model_info")
assert r.status_code == 200, r.text
info = r.json()
expected_keys = {"model_path", "tokenizer_path", "is_generation"}
assert expected_keys.issubset(info.keys())
# model_path must end with the one we passed on the CLI
assert info["model_path"].endswith(MODEL_ID)
# is_generation is documented as a boolean
assert isinstance(info["is_generation"], bool)
def test_unknown_route_returns_404(openai_server: str):
r = requests.get(f"{openai_server}/definitely-not-a-real-route")
assert r.status_code == 404
......@@ -57,11 +57,21 @@ class _MockTokenizerManager:
self.create_abort_task = Mock()
class _MockTemplateManager:
"""Minimal mock for TemplateManager."""
def __init__(self):
self.chat_template_name: Optional[str] = "llama-3"
self.jinja_template_content_format: Optional[str] = None
self.completion_template_name: Optional[str] = None
class ServingChatTestCase(unittest.TestCase):
# ------------- common fixtures -------------
def setUp(self):
self.tm = _MockTokenizerManager()
self.chat = OpenAIServingChat(self.tm)
self.template_manager = _MockTemplateManager()
self.chat = OpenAIServingChat(self.tm, self.template_manager)
# frequently reused requests
self.basic_req = ChatCompletionRequest(
......@@ -109,96 +119,6 @@ class ServingChatTestCase(unittest.TestCase):
self.assertFalse(adapted.stream)
self.assertEqual(processed, self.basic_req)
# # ------------- tool-call branch -------------
# def test_tool_call_request_conversion(self):
# req = ChatCompletionRequest(
# model="x",
# messages=[{"role": "user", "content": "Weather?"}],
# tools=[
# {
# "type": "function",
# "function": {
# "name": "get_weather",
# "parameters": {"type": "object", "properties": {}},
# },
# }
# ],
# tool_choice="auto",
# )
# with patch.object(
# self.chat,
# "_process_messages",
# return_value=("Prompt", [1, 2, 3], None, None, [], ["</s>"], None),
# ):
# adapted, _ = self.chat._convert_to_internal_request(req, "rid")
# self.assertEqual(adapted.rid, "rid")
# def test_tool_choice_none(self):
# req = ChatCompletionRequest(
# model="x",
# messages=[{"role": "user", "content": "Hi"}],
# tools=[{"type": "function", "function": {"name": "noop"}}],
# tool_choice="none",
# )
# with patch.object(
# self.chat,
# "_process_messages",
# return_value=("Prompt", [1, 2, 3], None, None, [], ["</s>"], None),
# ):
# adapted, _ = self.chat._convert_to_internal_request(req, "rid")
# self.assertEqual(adapted.rid, "rid")
# ------------- multimodal branch -------------
def test_multimodal_request_with_images(self):
self.tm.model_config.is_multimodal = True
req = ChatCompletionRequest(
model="x",
messages=[
{
"role": "user",
"content": [
{"type": "text", "text": "What's in the image?"},
{
"type": "image_url",
"image_url": {"url": "data:image/jpeg;base64,"},
},
],
}
],
)
with patch.object(
self.chat,
"_apply_jinja_template",
return_value=("prompt", [1, 2], ["img"], None, [], []),
), patch.object(
self.chat,
"_apply_conversation_template",
return_value=("prompt", ["img"], None, [], []),
):
out = self.chat._process_messages(req, True)
_, _, image_data, *_ = out
self.assertEqual(image_data, ["img"])
# ------------- template handling -------------
def test_jinja_template_processing(self):
req = ChatCompletionRequest(
model="x", messages=[{"role": "user", "content": "Hello"}]
)
self.tm.chat_template_name = None
self.tm.tokenizer.chat_template = "<jinja>"
with patch.object(
self.chat,
"_apply_jinja_template",
return_value=("processed", [1], None, None, [], ["</s>"]),
), patch("builtins.hasattr", return_value=True):
prompt, prompt_ids, *_ = self.chat._process_messages(req, False)
self.assertEqual(prompt, "processed")
self.assertEqual(prompt_ids, [1])
# ------------- sampling-params -------------
def test_sampling_param_build(self):
req = ChatCompletionRequest(
......
......@@ -5,6 +5,7 @@ Run with:
"""
import unittest
from typing import Optional
from unittest.mock import AsyncMock, Mock, patch
from sglang.srt.entrypoints.openai.protocol import CompletionRequest
......@@ -12,6 +13,17 @@ from sglang.srt.entrypoints.openai.serving_completions import OpenAIServingCompl
from sglang.srt.managers.tokenizer_manager import TokenizerManager
class _MockTemplateManager:
"""Minimal mock for TemplateManager."""
def __init__(self):
self.chat_template_name: Optional[str] = None
self.jinja_template_content_format: Optional[str] = None
self.completion_template_name: Optional[str] = (
None # Set to None to avoid template processing
)
class ServingCompletionTestCase(unittest.TestCase):
"""Bundle all prompt/echo tests in one TestCase."""
......@@ -31,7 +43,8 @@ class ServingCompletionTestCase(unittest.TestCase):
tm.generate_request = AsyncMock()
tm.create_abort_task = Mock()
self.sc = OpenAIServingCompletion(tm)
self.template_manager = _MockTemplateManager()
self.sc = OpenAIServingCompletion(tm, self.template_manager)
# ---------- prompt-handling ----------
def test_single_string_prompt(self):
......@@ -44,20 +57,6 @@ class ServingCompletionTestCase(unittest.TestCase):
internal, _ = self.sc._convert_to_internal_request(req)
self.assertEqual(internal.input_ids, [1, 2, 3, 4])
def test_completion_template_handling(self):
req = CompletionRequest(
model="x", prompt="def f():", suffix="return 1", max_tokens=100
)
with patch(
"sglang.srt.entrypoints.openai.serving_completions.is_completion_template_defined",
return_value=True,
), patch(
"sglang.srt.entrypoints.openai.serving_completions.generate_completion_prompt_from_request",
return_value="processed_prompt",
):
internal, _ = self.sc._convert_to_internal_request(req)
self.assertEqual(internal.text, "processed_prompt")
# ---------- echo-handling ----------
def test_echo_with_string_prompt_streaming(self):
req = CompletionRequest(model="x", prompt="Hello", max_tokens=1, echo=True)
......
......@@ -5,25 +5,16 @@ These tests ensure that the embedding serving implementation maintains compatibi
with the original adapter.py functionality and follows OpenAI API specifications.
"""
import asyncio
import json
import time
import unittest
import uuid
from typing import Any, Dict, List
from unittest.mock import AsyncMock, Mock, patch
from unittest.mock import Mock
from fastapi import Request
from fastapi.responses import ORJSONResponse
from pydantic_core import ValidationError
from sglang.srt.entrypoints.openai.protocol import (
EmbeddingObject,
EmbeddingRequest,
EmbeddingResponse,
ErrorResponse,
MultimodalEmbeddingInput,
UsageInfo,
)
from sglang.srt.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from sglang.srt.managers.io_struct import EmbeddingReqInput
......@@ -58,11 +49,22 @@ class _MockTokenizerManager:
self.generate_request = Mock(return_value=mock_generate_embedding())
# Mock TemplateManager for embedding tests
class _MockTemplateManager:
def __init__(self):
self.chat_template_name = None # None for embeddings usually
self.jinja_template_content_format = None
self.completion_template_name = None
class ServingEmbeddingTestCase(unittest.TestCase):
def setUp(self):
"""Set up test fixtures."""
self.tokenizer_manager = _MockTokenizerManager()
self.serving_embedding = OpenAIServingEmbedding(self.tokenizer_manager)
self.template_manager = _MockTemplateManager()
self.serving_embedding = OpenAIServingEmbedding(
self.tokenizer_manager, self.template_manager
)
self.request = Mock(spec=Request)
self.request.headers = {}
......@@ -141,132 +143,6 @@ class ServingEmbeddingTestCase(unittest.TestCase):
self.assertIsNone(adapted_request.image_data[1])
# self.assertEqual(adapted_request.rid, "test-id")
def test_build_single_embedding_response(self):
"""Test building response for single embedding."""
ret_data = [
{
"embedding": [0.1, 0.2, 0.3, 0.4, 0.5],
"meta_info": {"prompt_tokens": 5},
}
]
response = self.serving_embedding._build_embedding_response(ret_data)
self.assertIsInstance(response, EmbeddingResponse)
self.assertEqual(response.model, "test-model")
self.assertEqual(len(response.data), 1)
self.assertEqual(response.data[0].embedding, [0.1, 0.2, 0.3, 0.4, 0.5])
self.assertEqual(response.data[0].index, 0)
self.assertEqual(response.data[0].object, "embedding")
self.assertEqual(response.usage.prompt_tokens, 5)
self.assertEqual(response.usage.total_tokens, 5)
self.assertEqual(response.usage.completion_tokens, 0)
def test_build_multiple_embedding_response(self):
"""Test building response for multiple embeddings."""
ret_data = [
{
"embedding": [0.1, 0.2, 0.3],
"meta_info": {"prompt_tokens": 3},
},
{
"embedding": [0.4, 0.5, 0.6],
"meta_info": {"prompt_tokens": 4},
},
]
response = self.serving_embedding._build_embedding_response(ret_data)
self.assertIsInstance(response, EmbeddingResponse)
self.assertEqual(len(response.data), 2)
self.assertEqual(response.data[0].embedding, [0.1, 0.2, 0.3])
self.assertEqual(response.data[0].index, 0)
self.assertEqual(response.data[1].embedding, [0.4, 0.5, 0.6])
self.assertEqual(response.data[1].index, 1)
self.assertEqual(response.usage.prompt_tokens, 7) # 3 + 4
self.assertEqual(response.usage.total_tokens, 7)
def test_handle_request_success(self):
"""Test successful embedding request handling."""
async def run_test():
# Mock the generate_request to return expected data
async def mock_generate():
yield {
"embedding": [0.1, 0.2, 0.3, 0.4, 0.5],
"meta_info": {"prompt_tokens": 5},
}
self.serving_embedding.tokenizer_manager.generate_request = Mock(
return_value=mock_generate()
)
response = await self.serving_embedding.handle_request(
self.basic_req, self.request
)
self.assertIsInstance(response, EmbeddingResponse)
self.assertEqual(len(response.data), 1)
self.assertEqual(response.data[0].embedding, [0.1, 0.2, 0.3, 0.4, 0.5])
asyncio.run(run_test())
def test_handle_request_validation_error(self):
"""Test handling request with validation error."""
async def run_test():
invalid_request = EmbeddingRequest(model="test-model", input="")
response = await self.serving_embedding.handle_request(
invalid_request, self.request
)
self.assertIsInstance(response, ORJSONResponse)
self.assertEqual(response.status_code, 400)
asyncio.run(run_test())
def test_handle_request_generation_error(self):
"""Test handling request with generation error."""
async def run_test():
# Mock generate_request to raise an error
async def mock_generate_error():
raise ValueError("Generation failed")
yield # This won't be reached but needed for async generator
self.serving_embedding.tokenizer_manager.generate_request = Mock(
return_value=mock_generate_error()
)
response = await self.serving_embedding.handle_request(
self.basic_req, self.request
)
self.assertIsInstance(response, ORJSONResponse)
self.assertEqual(response.status_code, 400)
asyncio.run(run_test())
def test_handle_request_internal_error(self):
"""Test handling request with internal server error."""
async def run_test():
# Mock _convert_to_internal_request to raise an exception
with patch.object(
self.serving_embedding,
"_convert_to_internal_request",
side_effect=Exception("Internal error"),
):
response = await self.serving_embedding.handle_request(
self.basic_req, self.request
)
self.assertIsInstance(response, ORJSONResponse)
self.assertEqual(response.status_code, 500)
asyncio.run(run_test())
if __name__ == "__main__":
unittest.main(verbosity=2)
......@@ -29,6 +29,10 @@ suites = {
TestFile("models/test_reward_models.py", 132),
TestFile("models/test_vlm_models.py", 437),
TestFile("models/test_transformers_models.py", 320),
TestFile("openai/test_protocol.py", 10),
TestFile("openai/test_serving_chat.py", 10),
TestFile("openai/test_serving_completions.py", 10),
TestFile("openai/test_serving_embedding.py", 10),
TestFile("test_abort.py", 51),
TestFile("test_block_int8.py", 22),
TestFile("test_create_kvindices.py", 2),
......@@ -49,6 +53,7 @@ suites = {
TestFile("test_hidden_states.py", 55),
TestFile("test_int8_kernel.py", 8),
TestFile("test_input_embeddings.py", 38),
TestFile("test_jinja_template_utils.py", 1),
TestFile("test_json_constrained.py", 98),
TestFile("test_large_max_new_tokens.py", 41),
TestFile("test_metrics.py", 32),
......@@ -59,14 +64,8 @@ suites = {
TestFile("test_mla_fp8.py", 93),
TestFile("test_no_chunked_prefill.py", 108),
TestFile("test_no_overlap_scheduler.py", 234),
TestFile("test_openai_adapter.py", 1),
TestFile("test_openai_function_calling.py", 60),
TestFile("test_openai_server.py", 149),
TestFile("openai/test_server.py", 120),
TestFile("openai/test_protocol.py", 60),
TestFile("openai/test_serving_chat.py", 120),
TestFile("openai/test_serving_completions.py", 120),
TestFile("openai/test_serving_embedding.py", 120),
TestFile("test_openai_server_hidden_states.py", 240),
TestFile("test_penalty.py", 41),
TestFile("test_page_size.py", 60),
......
......@@ -3,6 +3,7 @@ import unittest
from xgrammar import GrammarCompiler, TokenizerInfo
from sglang.srt.entrypoints.openai.protocol import Function, Tool
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
from sglang.srt.function_call.deepseekv3_detector import DeepSeekV3Detector
from sglang.srt.function_call.llama32_detector import Llama32Detector
......@@ -10,7 +11,6 @@ from sglang.srt.function_call.mistral_detector import MistralDetector
from sglang.srt.function_call.pythonic_detector import PythonicDetector
from sglang.srt.function_call.qwen25_detector import Qwen25Detector
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.openai_api.protocol import Function, Tool
from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST
......
......@@ -5,8 +5,8 @@ Unit tests for OpenAI adapter utils.
import unittest
from unittest.mock import patch
from sglang.srt.openai_api.utils import (
detect_template_content_format,
from sglang.srt.jinja_template_utils import (
detect_jinja_template_content_format,
process_content_for_template_format,
)
from sglang.test.test_utils import CustomTestCase
......@@ -33,7 +33,7 @@ class TestTemplateContentFormatDetection(CustomTestCase):
{%- endfor %}
"""
result = detect_template_content_format(llama4_pattern)
result = detect_jinja_template_content_format(llama4_pattern)
self.assertEqual(result, "openai")
def test_detect_deepseek_string_format(self):
......@@ -46,19 +46,19 @@ class TestTemplateContentFormatDetection(CustomTestCase):
{%- endfor %}
"""
result = detect_template_content_format(deepseek_pattern)
result = detect_jinja_template_content_format(deepseek_pattern)
self.assertEqual(result, "string")
def test_detect_invalid_template(self):
"""Test handling of invalid template (should default to 'string')."""
invalid_pattern = "{{{{ invalid jinja syntax }}}}"
result = detect_template_content_format(invalid_pattern)
result = detect_jinja_template_content_format(invalid_pattern)
self.assertEqual(result, "string")
def test_detect_empty_template(self):
"""Test handling of empty template (should default to 'string')."""
result = detect_template_content_format("")
result = detect_jinja_template_content_format("")
self.assertEqual(result, "string")
def test_process_content_openai_format(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment