Unverified commit 72676cd6, authored by Chang Su and committed by GitHub

feat(oai refactor): Replace `openai_api` with `entrypoints/openai` (#7351)


Co-authored-by: Jin Pan <jpan236@wisc.edu>
parent 02bf31ef
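For downstream code, the user-visible change in this commit is the import path: the OpenAI protocol models move from `sglang.srt.openai_api.protocol` to `sglang.srt.entrypoints.openai.protocol`. A minimal migration sketch (class names are taken from the imports touched in this diff; building the model from plain dicts assumes the usual pydantic coercion):

```python
# Old import path (removed by this commit):
# from sglang.srt.openai_api.protocol import ChatCompletionRequest

# New import path introduced by this commit:
from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest

# The request model is a pydantic class, so it can be constructed directly for testing.
request = ChatCompletionRequest(
    model="Qwen/Qwen2.5-VL-7B-Instruct",
    messages=[{"role": "user", "content": "Hello"}],
)
print(request.model, len(request.messages))
```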
...@@ -20,7 +20,7 @@ from sglang.bench_serving import (
     get_gen_prefix_cache_path,
 )
 from sglang.lang.chat_template import get_chat_template, get_chat_template_by_model_path
-from sglang.srt.openai_api.protocol import ChatCompletionMessageContentPart
+from sglang.srt.entrypoints.openai.protocol import ChatCompletionMessageContentPart
 from sglang.utils import encode_video_base64

 # type of content fields, can be only prompts or with images/videos
......
...@@ -64,11 +64,14 @@
 "text = \"Once upon a time\"\n",
 "\n",
 "curl_text = f\"\"\"curl -s http://localhost:{port}/v1/embeddings \\\n",
+" -H \"Content-Type: application/json\" \\\n",
 " -d '{{\"model\": \"Alibaba-NLP/gte-Qwen2-1.5B-instruct\", \"input\": \"{text}\"}}'\"\"\"\n",
 "\n",
-"text_embedding = json.loads(subprocess.check_output(curl_text, shell=True))[\"data\"][0][\n",
-" \"embedding\"\n",
-"]\n",
+"result = subprocess.check_output(curl_text, shell=True)\n",
+"\n",
+"print(result)\n",
+"\n",
+"text_embedding = json.loads(result)[\"data\"][0][\"embedding\"]\n",
 "\n",
 "print_highlight(f\"Text embedding (first 10): {text_embedding[:10]}\")"
 ]
...@@ -152,6 +155,7 @@
 "input_ids = tokenizer.encode(text)\n",
 "\n",
 "curl_ids = f\"\"\"curl -s http://localhost:{port}/v1/embeddings \\\n",
+" -H \"Content-Type: application/json\" \\\n",
 " -d '{{\"model\": \"Alibaba-NLP/gte-Qwen2-1.5B-instruct\", \"input\": {json.dumps(input_ids)}}}'\"\"\"\n",
 "\n",
 "input_ids_embedding = json.loads(subprocess.check_output(curl_ids, shell=True))[\"data\"][\n",
......
...@@ -67,6 +67,7 @@
 "\n",
 "curl_command = f\"\"\"\n",
 "curl -s http://localhost:{port}/v1/chat/completions \\\\\n",
+" -H \"Content-Type: application/json\" \\\\\n",
 " -d '{{\n",
 " \"model\": \"Qwen/Qwen2.5-VL-7B-Instruct\",\n",
 " \"messages\": [\n",
......
...@@ -36,7 +36,7 @@
 "import requests\n",
 "from PIL import Image\n",
 "\n",
-"from sglang.srt.openai_api.protocol import ChatCompletionRequest\n",
+"from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest\n",
 "from sglang.srt.conversation import chat_templates\n",
 "\n",
 "image = Image.open(\n",
......
...@@ -15,9 +15,7 @@
 import dataclasses
-import json
 import logging
-import os
 from enum import auto

 from sglang.srt.entrypoints.openai.protocol import CompletionRequest
...@@ -57,46 +55,6 @@ class CompletionTemplate:
 completion_templates: dict[str, CompletionTemplate] = {}


-def load_completion_template_for_openai_api(completion_template_arg):
-    global completion_template_name
-
-    logger.info(
-        f"Use completion template for the OpenAI-compatible API server: {completion_template_arg}"
-    )
-
-    if not completion_template_exists(completion_template_arg):
-        if not os.path.exists(completion_template_arg):
-            raise RuntimeError(
-                f"Completion template {completion_template_arg} is not a built-in template name "
-                "or a valid completion template file path."
-            )
-
-        assert completion_template_arg.endswith(
-            ".json"
-        ), "unrecognized format of completion template file"
-        with open(completion_template_arg, "r") as filep:
-            template = json.load(filep)
-            try:
-                fim_position = FimPosition[template["fim_position"]]
-            except KeyError:
-                raise ValueError(
-                    f"Unknown fim position: {template['fim_position']}"
-                ) from None
-            register_completion_template(
-                CompletionTemplate(
-                    name=template["name"],
-                    fim_begin_token=template["fim_begin_token"],
-                    fim_middle_token=template["fim_middle_token"],
-                    fim_end_token=template["fim_end_token"],
-                    fim_position=fim_position,
-                ),
-                override=True,
-            )
-        completion_template_name = template["name"]
-    else:
-        completion_template_name = completion_template_arg
-
-
 def register_completion_template(template: CompletionTemplate, override: bool = False):
     """Register a new completion template."""
     if not override:
......
...@@ -11,7 +11,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Conversation chat templates."""
+"""Conversation chat templates.
+
+This module provides conversation template definitions, data structures, and utilities
+for managing chat templates across different model types in SGLang.
+
+Key components:
+- Conversation class: Defines the structure and behavior of chat templates
+- SeparatorStyle enum: Different conversation formatting styles
+- Template registry: Functions to register and retrieve templates by name or model path
+- Built-in templates: Pre-defined templates for popular models
+"""
 # Adapted from
 # https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
...@@ -20,7 +30,7 @@ import re
 from enum import IntEnum, auto
 from typing import Callable, Dict, List, Optional, Tuple, Union

-from sglang.srt.openai_api.protocol import ChatCompletionRequest
+from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest
 from sglang.srt.utils import read_system_prompt_from_file
...@@ -618,7 +628,7 @@ def generate_chat_conv(
 # llama2 template
-# reference: https://huggingface.co/blog/codellama#conversational-instructions
+# reference: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
 # reference: https://github.com/facebookresearch/llama/blob/1a240688810f8036049e8da36b073f63d2ac552c/llama/generation.py#L212
 register_conv_template(
     Conversation(
......
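The expanded module docstring above describes the template registry that the new serving layer consumes. A rough usage sketch (both call sites, `generate_chat_conv` and `conv.get_prompt()`, appear later in this diff; the "llama-2" template name is an assumption about the built-in registry):

```python
from sglang.srt.conversation import generate_chat_conv
from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest

request = ChatCompletionRequest(
    model="meta-llama/Llama-2-7b-chat-hf",
    messages=[{"role": "user", "content": "Hello!"}],
)

# serving_chat.py calls generate_chat_conv(request, chat_template_name) the same way;
# "llama-2" is assumed to be one of the built-in template names registered in this module.
conv = generate_chat_conv(request, "llama-2")
print(conv.get_prompt())
```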
...@@ -37,7 +37,6 @@ setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
 import torch
 import uvloop

-from sglang.srt.code_completion_parser import load_completion_template_for_openai_api
 from sglang.srt.entrypoints.EngineBase import EngineBase
 from sglang.srt.managers.data_parallel_controller import (
     run_data_parallel_controller_process,
...@@ -58,11 +57,8 @@ from sglang.srt.managers.io_struct import (
     UpdateWeightsFromTensorReqInput,
 )
 from sglang.srt.managers.scheduler import run_scheduler_process
+from sglang.srt.managers.template_manager import TemplateManager
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
-from sglang.srt.openai_api.adapter import (
-    guess_chat_template_name_from_model_path,
-    load_chat_template_for_openai_api,
-)
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.srt.utils import (
...@@ -123,12 +119,13 @@ class Engine(EngineBase):
         logger.info(f"{server_args=}")

         # Launch subprocesses
-        tokenizer_manager, scheduler_info = _launch_subprocesses(
+        tokenizer_manager, template_manager, scheduler_info = _launch_subprocesses(
             server_args=server_args,
             port_args=port_args,
         )
         self.server_args = server_args
         self.tokenizer_manager = tokenizer_manager
+        self.template_manager = template_manager
         self.scheduler_info = scheduler_info

         context = zmq.Context(2)
...@@ -647,7 +644,7 @@ def _set_envs_and_config(server_args: ServerArgs):
 def _launch_subprocesses(
     server_args: ServerArgs, port_args: Optional[PortArgs] = None
-) -> Tuple[TokenizerManager, Dict]:
+) -> Tuple[TokenizerManager, TemplateManager, Dict]:
     """
     Launch the TokenizerManager in the main process, the Scheduler in a subprocess, and the DetokenizerManager in another subprocess.
     """
...@@ -732,7 +729,7 @@ _launch_subprocesses(
         if os.getenv("SGLANG_BLOCK_NONZERO_RANK_CHILDREN") == "0":
             # When using `Engine` as a Python API, we don't want to block here.
-            return None, None
+            return None, None, None

         launch_dummy_health_check_server(server_args.host, server_args.port)
...@@ -741,7 +738,7 @@
            logger.error(
                f"Scheduler or DataParallelController {proc.pid} terminated with {proc.exitcode}"
            )
-        return None, None
+        return None, None, None

     # Launch detokenizer process
     detoken_proc = mp.Process(
...@@ -755,15 +752,15 @@
     # Launch tokenizer process
     tokenizer_manager = TokenizerManager(server_args, port_args)
-    if server_args.chat_template:
-        load_chat_template_for_openai_api(
-            tokenizer_manager, server_args.chat_template, server_args.model_path
-        )
-    else:
-        guess_chat_template_name_from_model_path(server_args.model_path)
-    if server_args.completion_template:
-        load_completion_template_for_openai_api(server_args.completion_template)
+
+    # Initialize templates
+    template_manager = TemplateManager()
+    template_manager.initialize_templates(
+        tokenizer_manager=tokenizer_manager,
+        model_path=server_args.model_path,
+        chat_template=server_args.chat_template,
+        completion_template=server_args.completion_template,
+    )

     # Wait for the model to finish loading
     scheduler_infos = []
...@@ -787,4 +784,4 @@
     # Assume all schedulers have the same scheduler_info
     scheduler_info = scheduler_infos[0]
     tokenizer_manager.max_req_input_len = scheduler_info["max_req_input_len"]
-    return tokenizer_manager, scheduler_info
+    return tokenizer_manager, template_manager, scheduler_info
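`_launch_subprocesses` now returns a three-tuple, and the removed module-level template loaders are replaced by a `TemplateManager` instance. A short sketch restating the new initialization flow from the hunk above (the `chat_template_name` / `completion_template_name` attributes are the ones read by the serving handlers later in this diff):

```python
from sglang.srt.managers.template_manager import TemplateManager

def init_templates(tokenizer_manager, server_args) -> TemplateManager:
    # Replaces the removed load_chat_template_for_openai_api /
    # guess_chat_template_name_from_model_path / load_completion_template_for_openai_api calls.
    template_manager = TemplateManager()
    template_manager.initialize_templates(
        tokenizer_manager=tokenizer_manager,
        model_path=server_args.model_path,
        chat_template=server_args.chat_template,
        completion_template=server_args.completion_template,
    )
    # Handlers read the resolved names from the manager instead of module-level globals.
    return template_manager
```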
...@@ -38,7 +38,8 @@ import orjson
 import requests
 import uvicorn
 import uvloop
-from fastapi import FastAPI, File, Form, Request, UploadFile
+from fastapi import Depends, FastAPI, Request, UploadFile
+from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import ORJSONResponse, Response, StreamingResponse
...@@ -47,6 +48,20 @@ from sglang.srt.disaggregation.utils import ( ...@@ -47,6 +48,20 @@ from sglang.srt.disaggregation.utils import (
register_disaggregation_server, register_disaggregation_server,
) )
from sglang.srt.entrypoints.engine import _launch_subprocesses from sglang.srt.entrypoints.engine import _launch_subprocesses
from sglang.srt.entrypoints.openai.protocol import (
ChatCompletionRequest,
CompletionRequest,
EmbeddingRequest,
ModelCard,
ModelList,
ScoringRequest,
V1RerankReqInput,
)
from sglang.srt.entrypoints.openai.serving_chat import OpenAIServingChat
from sglang.srt.entrypoints.openai.serving_completions import OpenAIServingCompletion
from sglang.srt.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from sglang.srt.entrypoints.openai.serving_rerank import OpenAIServingRerank
from sglang.srt.entrypoints.openai.serving_score import OpenAIServingScore
from sglang.srt.function_call.function_call_parser import FunctionCallParser from sglang.srt.function_call.function_call_parser import FunctionCallParser
from sglang.srt.managers.io_struct import ( from sglang.srt.managers.io_struct import (
AbortReq, AbortReq,
...@@ -67,26 +82,11 @@ from sglang.srt.managers.io_struct import ( ...@@ -67,26 +82,11 @@ from sglang.srt.managers.io_struct import (
UpdateWeightFromDiskReqInput, UpdateWeightFromDiskReqInput,
UpdateWeightsFromDistributedReqInput, UpdateWeightsFromDistributedReqInput,
UpdateWeightsFromTensorReqInput, UpdateWeightsFromTensorReqInput,
V1RerankReqInput,
VertexGenerateReqInput, VertexGenerateReqInput,
) )
from sglang.srt.managers.template_manager import TemplateManager
from sglang.srt.managers.tokenizer_manager import TokenizerManager from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.metrics.func_timer import enable_func_timer from sglang.srt.metrics.func_timer import enable_func_timer
from sglang.srt.openai_api.adapter import (
v1_batches,
v1_cancel_batch,
v1_chat_completions,
v1_completions,
v1_delete_file,
v1_embeddings,
v1_files_create,
v1_rerank,
v1_retrieve_batch,
v1_retrieve_file,
v1_retrieve_file_content,
v1_score,
)
from sglang.srt.openai_api.protocol import ModelCard, ModelList
from sglang.srt.reasoning_parser import ReasoningParser from sglang.srt.reasoning_parser import ReasoningParser
from sglang.srt.server_args import ServerArgs from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import ( from sglang.srt.utils import (
...@@ -109,6 +109,7 @@ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) ...@@ -109,6 +109,7 @@ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
@dataclasses.dataclass @dataclasses.dataclass
class _GlobalState: class _GlobalState:
tokenizer_manager: TokenizerManager tokenizer_manager: TokenizerManager
template_manager: TemplateManager
scheduler_info: Dict scheduler_info: Dict
...@@ -123,6 +124,24 @@ def set_global_state(global_state: _GlobalState): ...@@ -123,6 +124,24 @@ def set_global_state(global_state: _GlobalState):
@asynccontextmanager @asynccontextmanager
async def lifespan(fast_api_app: FastAPI): async def lifespan(fast_api_app: FastAPI):
server_args: ServerArgs = fast_api_app.server_args server_args: ServerArgs = fast_api_app.server_args
# Initialize OpenAI serving handlers
fast_api_app.state.openai_serving_completion = OpenAIServingCompletion(
_global_state.tokenizer_manager, _global_state.template_manager
)
fast_api_app.state.openai_serving_chat = OpenAIServingChat(
_global_state.tokenizer_manager, _global_state.template_manager
)
fast_api_app.state.openai_serving_embedding = OpenAIServingEmbedding(
_global_state.tokenizer_manager, _global_state.template_manager
)
fast_api_app.state.openai_serving_score = OpenAIServingScore(
_global_state.tokenizer_manager
)
fast_api_app.state.openai_serving_rerank = OpenAIServingRerank(
_global_state.tokenizer_manager
)
if server_args.warmups is not None: if server_args.warmups is not None:
await execute_warmups( await execute_warmups(
server_args.warmups.split(","), _global_state.tokenizer_manager server_args.warmups.split(","), _global_state.tokenizer_manager
...@@ -148,6 +167,36 @@ app.add_middleware( ...@@ -148,6 +167,36 @@ app.add_middleware(
allow_headers=["*"], allow_headers=["*"],
) )
# Custom exception handlers to change validation error status codes
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
"""Override FastAPI's default 422 validation error with 400"""
return ORJSONResponse(
status_code=400,
content={
"detail": exc.errors(),
"body": exc.body,
},
)
async def validate_json_request(raw_request: Request):
"""Validate that the request content-type is application/json."""
content_type = raw_request.headers.get("content-type", "").lower()
media_type = content_type.split(";", maxsplit=1)[0]
if media_type != "application/json":
raise RequestValidationError(
errors=[
{
"loc": ["header", "content-type"],
"msg": "Unsupported Media Type: Only 'application/json' is allowed",
"type": "value_error",
}
]
)
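Together, the two handlers above change the failure mode for malformed requests: a missing or non-JSON content type is rejected with 400 rather than FastAPI's default 422. A hedged sketch of the observable behavior (host and port are assumptions):

```python
import requests

base = "http://localhost:30000"

# No application/json content type -> validate_json_request raises
# RequestValidationError, which the custom handler above maps to HTTP 400.
bad = requests.post(
    f"{base}/v1/chat/completions",
    data='{"model": "default", "messages": [{"role": "user", "content": "hi"}]}',
)
print(bad.status_code)  # expected: 400

# Sending real JSON sets Content-Type: application/json and passes the dependency check.
good = requests.post(
    f"{base}/v1/chat/completions",
    json={"model": "default", "messages": [{"role": "user", "content": "hi"}]},
)
print(good.status_code)
```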
HEALTH_CHECK_TIMEOUT = int(os.getenv("SGLANG_HEALTH_CHECK_TIMEOUT", 20)) HEALTH_CHECK_TIMEOUT = int(os.getenv("SGLANG_HEALTH_CHECK_TIMEOUT", 20))
...@@ -330,13 +379,14 @@ async def classify_request(obj: EmbeddingReqInput, request: Request):
         return _create_error_response(e)


-@app.api_route("/v1/rerank", methods=["POST", "PUT"])
-async def v1_rerank_request(obj: V1RerankReqInput, raw_request: Request):
-    try:
-        ret = await v1_rerank(_global_state.tokenizer_manager, obj, raw_request)
-        return ret
-    except ValueError as e:
-        return _create_error_response(e)
+@app.api_route(
+    "/v1/rerank", methods=["POST", "PUT"], dependencies=[Depends(validate_json_request)]
+)
+async def v1_rerank_request(request: V1RerankReqInput, raw_request: Request):
+    """Endpoint for reranking documents based on query relevance."""
+    return await raw_request.app.state.openai_serving_rerank.handle_request(
+        request, raw_request
+    )
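A hedged usage sketch for the rerank route wired up above. The exact request fields are defined by `V1RerankReqInput` in `sglang.srt.managers.io_struct`; the "query"/"documents" names below are assumptions, as are the URL and payload contents:

```python
import requests

resp = requests.post(
    "http://localhost:30000/v1/rerank",
    json={
        "query": "what is the capital of France?",  # assumed field name
        "documents": [  # assumed field name
            "Paris is the capital of France.",
            "The Moon orbits the Earth.",
        ],
    },
    headers={"Content-Type": "application/json"},  # required by validate_json_request
)
print(resp.json())
```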
@app.api_route("/flush_cache", methods=["GET", "POST"]) @app.api_route("/flush_cache", methods=["GET", "POST"])
...@@ -619,25 +669,39 @@ async def separate_reasoning_request(obj: SeparateReasoningReqInput, request: Re
 ##### OpenAI-compatible API endpoints #####


-@app.post("/v1/completions")
-async def openai_v1_completions(raw_request: Request):
-    return await v1_completions(_global_state.tokenizer_manager, raw_request)
+@app.post("/v1/completions", dependencies=[Depends(validate_json_request)])
+async def openai_v1_completions(request: CompletionRequest, raw_request: Request):
+    """OpenAI-compatible text completion endpoint."""
+    return await raw_request.app.state.openai_serving_completion.handle_request(
+        request, raw_request
+    )


-@app.post("/v1/chat/completions")
-async def openai_v1_chat_completions(raw_request: Request):
-    return await v1_chat_completions(_global_state.tokenizer_manager, raw_request)
+@app.post("/v1/chat/completions", dependencies=[Depends(validate_json_request)])
+async def openai_v1_chat_completions(
+    request: ChatCompletionRequest, raw_request: Request
+):
+    """OpenAI-compatible chat completion endpoint."""
+    return await raw_request.app.state.openai_serving_chat.handle_request(
+        request, raw_request
+    )


-@app.post("/v1/embeddings", response_class=ORJSONResponse)
-async def openai_v1_embeddings(raw_request: Request):
-    response = await v1_embeddings(_global_state.tokenizer_manager, raw_request)
-    return response
+@app.post(
+    "/v1/embeddings",
+    response_class=ORJSONResponse,
+    dependencies=[Depends(validate_json_request)],
+)
+async def openai_v1_embeddings(request: EmbeddingRequest, raw_request: Request):
+    """OpenAI-compatible embeddings endpoint."""
+    return await raw_request.app.state.openai_serving_embedding.handle_request(
+        request, raw_request
+    )
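Because the routes above keep the OpenAI wire format, the official client keeps working against the refactored handlers. A hedged sketch for the embeddings endpoint (port and model name are assumptions; they match the notebook change earlier in this diff):

```python
import openai

client = openai.Client(base_url="http://localhost:30000/v1", api_key="None")
resp = client.embeddings.create(
    model="Alibaba-NLP/gte-Qwen2-1.5B-instruct",
    input="Once upon a time",
)
print(resp.data[0].embedding[:10])
```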
@app.get("/v1/models", response_class=ORJSONResponse) @app.get("/v1/models", response_class=ORJSONResponse)
def available_models(): async def available_models():
"""Show available models.""" """Show available models. OpenAI-compatible endpoint."""
served_model_names = [_global_state.tokenizer_manager.served_model_name] served_model_names = [_global_state.tokenizer_manager.served_model_name]
model_cards = [] model_cards = []
for served_model_name in served_model_names: for served_model_name in served_model_names:
...@@ -651,45 +715,29 @@ def available_models(): ...@@ -651,45 +715,29 @@ def available_models():
return ModelList(data=model_cards) return ModelList(data=model_cards)
@app.post("/v1/files") @app.get("/v1/models/{model:path}", response_class=ORJSONResponse)
async def openai_v1_files(file: UploadFile = File(...), purpose: str = Form("batch")): async def retrieve_model(model: str):
return await v1_files_create( """Retrieves a model instance, providing basic information about the model."""
file, purpose, _global_state.tokenizer_manager.server_args.file_storage_path served_model_names = [_global_state.tokenizer_manager.served_model_name]
)
@app.delete("/v1/files/{file_id}")
async def delete_file(file_id: str):
# https://platform.openai.com/docs/api-reference/files/delete
return await v1_delete_file(file_id)
@app.post("/v1/batches")
async def openai_v1_batches(raw_request: Request):
return await v1_batches(_global_state.tokenizer_manager, raw_request)
@app.post("/v1/batches/{batch_id}/cancel")
async def cancel_batches(batch_id: str):
# https://platform.openai.com/docs/api-reference/batch/cancel
return await v1_cancel_batch(_global_state.tokenizer_manager, batch_id)
@app.get("/v1/batches/{batch_id}")
async def retrieve_batch(batch_id: str):
return await v1_retrieve_batch(batch_id)
@app.get("/v1/files/{file_id}")
async def retrieve_file(file_id: str):
# https://platform.openai.com/docs/api-reference/files/retrieve
return await v1_retrieve_file(file_id)
if model not in served_model_names:
return ORJSONResponse(
status_code=404,
content={
"error": {
"message": f"The model '{model}' does not exist",
"type": "invalid_request_error",
"param": "model",
"code": "model_not_found",
}
},
)
@app.get("/v1/files/{file_id}/content") return ModelCard(
async def retrieve_file_content(file_id: str): id=model,
# https://platform.openai.com/docs/api-reference/files/retrieve-contents root=model,
return await v1_retrieve_file_content(file_id) max_model_len=_global_state.tokenizer_manager.model_config.context_len,
)
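A hedged sketch exercising the new `/v1/models/{model}` route added above, alongside the existing `/v1/models` listing; URL and the missing-model name are assumptions:

```python
import requests

base = "http://localhost:30000"

# List the served models, then fetch the first one individually.
served = requests.get(f"{base}/v1/models").json()["data"][0]["id"]
ok = requests.get(f"{base}/v1/models/{served}")
print(ok.status_code, ok.json()["max_model_len"])

# Unknown model names return the 404 model_not_found error body built above.
missing = requests.get(f"{base}/v1/models/not-a-real-model")
print(missing.status_code)  # expected: 404
```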
## SageMaker API ## SageMaker API
...@@ -700,8 +748,13 @@ async def sagemaker_health() -> Response:
 @app.post("/invocations")
-async def sagemaker_chat_completions(raw_request: Request):
-    return await v1_chat_completions(_global_state.tokenizer_manager, raw_request)
+async def sagemaker_chat_completions(
+    request: ChatCompletionRequest, raw_request: Request
+):
+    """OpenAI-compatible chat completion endpoint."""
+    return await raw_request.app.state.openai_serving_chat.handle_request(
+        request, raw_request
+    )
 ## Vertex AI API
...@@ -732,10 +785,12 @@ async def vertex_generate(vertex_req: VertexGenerateReqInput, raw_request: Reque
     return ORJSONResponse({"predictions": ret})


-@app.post("/v1/score")
-async def v1_score_request(raw_request: Request):
+@app.post("/v1/score", dependencies=[Depends(validate_json_request)])
+async def v1_score_request(request: ScoringRequest, raw_request: Request):
     """Endpoint for the decoder-only scoring API. See Engine.score() for detailed documentation."""
-    return await v1_score(_global_state.tokenizer_manager, raw_request)
+    return await raw_request.app.state.openai_serving_score.handle_request(
+        request, raw_request
+    )
 def _create_error_response(e):
...@@ -764,10 +819,13 @@ def launch_server(
     1. The HTTP server, Engine, and TokenizerManager both run in the main process.
     2. Inter-process communication is done through IPC (each process uses a different port) via the ZMQ library.
     """
-    tokenizer_manager, scheduler_info = _launch_subprocesses(server_args=server_args)
+    tokenizer_manager, template_manager, scheduler_info = _launch_subprocesses(
+        server_args=server_args
+    )
     set_global_state(
         _GlobalState(
             tokenizer_manager=tokenizer_manager,
+            template_manager=template_manager,
             scheduler_info=scheduler_info,
         )
     )
......
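For reference, a hedged sketch of starting the refactored server programmatically with the updated `launch_server`; `ServerArgs` fields beyond `model_path` are left at their defaults, and calling `launch_server` directly like this is an assumption (the usual entrypoint is the `sglang.launch_server` CLI wrapper):

```python
from sglang.srt.entrypoints.http_server import launch_server
from sglang.srt.server_args import ServerArgs

if __name__ == "__main__":
    # model_path is the only required ServerArgs field; everything else is a default here.
    server_args = ServerArgs(model_path="Qwen/Qwen2.5-VL-7B-Instruct")
    launch_server(server_args)
```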
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
SGLang OpenAI-Compatible API Server.
This file implements OpenAI-compatible HTTP APIs for the inference engine via FastAPI.
"""
import argparse
import asyncio
import logging
import multiprocessing
import os
import threading
import time
from contextlib import asynccontextmanager
from typing import Callable, Dict, Optional
import numpy as np
import requests
import uvicorn
import uvloop
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import Response
from sglang.srt.disaggregation.utils import (
FAKE_BOOTSTRAP_HOST,
register_disaggregation_server,
)
from sglang.srt.entrypoints.engine import Engine, _launch_subprocesses
from sglang.srt.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.metrics.func_timer import enable_func_timer
from sglang.srt.openai_api.protocol import EmbeddingRequest, ModelCard, ModelList
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import (
add_prometheus_middleware,
delete_directory,
get_bool_env_var,
kill_process_tree,
set_uvicorn_logging_configs,
)
from sglang.srt.warmup import execute_warmups
from sglang.utils import get_exception_traceback
logger = logging.getLogger(__name__)
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
# Store global states
class AppState:
engine: Optional[Engine] = None
server_args: Optional[ServerArgs] = None
tokenizer_manager: Optional[TokenizerManager] = None
scheduler_info: Optional[Dict] = None
embedding_server: Optional[OpenAIServingEmbedding] = None
@asynccontextmanager
async def lifespan(app: FastAPI):
app.state.server_args.enable_metrics = True # By default, we enable metrics
server_args = app.state.server_args
# Initialize engine
logger.info(f"SGLang OpenAI server (PID: {os.getpid()}) is initializing...")
tokenizer_manager, scheduler_info = _launch_subprocesses(server_args=server_args)
app.state.tokenizer_manager = tokenizer_manager
app.state.scheduler_info = scheduler_info
app.state.serving_embedding = OpenAIServingEmbedding(
tokenizer_manager=tokenizer_manager
)
if server_args.enable_metrics:
add_prometheus_middleware(app)
enable_func_timer()
# Initialize engine state attribute to None for now
app.state.engine = None
if server_args.warmups is not None:
await execute_warmups(
server_args.warmups.split(","), app.state.tokenizer_manager
)
logger.info("Warmup ended")
warmup_thread = getattr(app, "warmup_thread", None)
if warmup_thread is not None:
warmup_thread.start()
yield
# Lifespan shutdown
if hasattr(app.state, "engine") and app.state.engine is not None:
logger.info("SGLang engine is shutting down.")
# Add engine cleanup logic here when implemented
# Fast API app with CORS enabled
app = FastAPI(
lifespan=lifespan,
# TODO: check where /openapi.json is created and why we use it
openapi_url=None if get_bool_env_var("DISABLE_OPENAPI_DOC") else "/openapi.json",
)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
@app.api_route("/health", methods=["GET"])
async def health() -> Response:
"""Health check. Used for readiness and liveness probes."""
# In the future, this could check engine health more deeply
# For now, if the server is up, it's healthy.
return Response(status_code=200)
@app.api_route("/v1/models", methods=["GET"])
async def show_models():
"""Show available models. Currently, it returns the served model name.
This endpoint is compatible with the OpenAI API standard.
"""
served_model_names = [app.state.tokenizer_manager.served_model_name]
model_cards = []
for served_model_name in served_model_names:
model_cards.append(
ModelCard(
id=served_model_name,
root=served_model_name,
max_model_len=app.state.tokenizer_manager.model_config.context_len,
)
)
return ModelList(data=model_cards)
@app.get("/get_model_info")
async def get_model_info():
"""Get the model information."""
result = {
"model_path": app.state.tokenizer_manager.model_path,
"tokenizer_path": app.state.tokenizer_manager.server_args.tokenizer_path,
"is_generation": app.state.tokenizer_manager.is_generation,
}
return result
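A hedged sketch of querying the `/get_model_info` route above (host and port are assumptions); the returned keys are the ones assembled in the handler:

```python
import requests

info = requests.get("http://localhost:30000/get_model_info").json()
print(info["model_path"], info["tokenizer_path"], info["is_generation"])
```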
@app.post("/v1/completions")
async def openai_v1_completions(raw_request: Request):
pass
@app.post("/v1/chat/completions")
async def openai_v1_chat_completions(raw_request: Request):
pass
@app.post("/v1/embeddings")
async def openai_v1_embeddings(raw_request: Request):
try:
request_json = await raw_request.json()
request = EmbeddingRequest(**request_json)
except Exception as e:
return app.state.serving_embedding.create_error_response(
f"Invalid request body, error: {str(e)}"
)
ret = await app.state.serving_embedding.handle_request(request, raw_request)
return ret
@app.post("/v1/score")
async def v1_score_request(raw_request: Request):
"""Endpoint for the decoder-only scoring API. See Engine.score() for detailed documentation."""
pass
@app.api_route("/v1/models/{model_id}", methods=["GET"])
async def show_model_detail(model_id: str):
served_model_name = app.state.tokenizer_manager.served_model_name
return ModelCard(
id=served_model_name,
root=served_model_name,
max_model_len=app.state.tokenizer_manager.model_config.context_len,
)
# Additional API endpoints will be implemented in separate serving_*.py modules
# and mounted as APIRouters in future PRs
def _wait_and_warmup(
server_args: ServerArgs,
pipe_finish_writer: Optional[multiprocessing.connection.Connection],
image_token_text: str,
launch_callback: Optional[Callable[[], None]] = None,
):
return
# TODO: Please wait until the /generate implementation is complete,
# or confirm if modifications are needed before removing this.
headers = {}
url = server_args.url()
if server_args.api_key:
headers["Authorization"] = f"Bearer {server_args.api_key}"
# Wait until the server is launched
success = False
for _ in range(120):
time.sleep(1)
try:
res = requests.get(url + "/get_model_info", timeout=5, headers=headers)
assert res.status_code == 200, f"{res=}, {res.text=}"
success = True
break
except (AssertionError, requests.exceptions.RequestException):
last_traceback = get_exception_traceback()
pass
if not success:
if pipe_finish_writer is not None:
pipe_finish_writer.send(last_traceback)
logger.error(f"Initialization failed. warmup error: {last_traceback}")
kill_process_tree(os.getpid())
return
model_info = res.json()
# Send a warmup request
request_name = "/generate" if model_info["is_generation"] else "/encode"
# TODO: Replace with OpenAI API
max_new_tokens = 8 if model_info["is_generation"] else 1
json_data = {
"sampling_params": {
"temperature": 0,
"max_new_tokens": max_new_tokens,
},
}
if server_args.skip_tokenizer_init:
json_data["input_ids"] = [[10, 11, 12] for _ in range(server_args.dp_size)]
# TODO: Work around the bug where embedding fails for an input list of size 1
if server_args.dp_size == 1:
json_data["input_ids"] = json_data["input_ids"][0]
else:
json_data["text"] = ["The capital city of France is"] * server_args.dp_size
# TODO: Work around the bug where embedding fails for an input list of size 1
if server_args.dp_size == 1:
json_data["text"] = json_data["text"][0]
# Debug dumping
if server_args.debug_tensor_dump_input_file:
json_data.pop("text", None)
json_data["input_ids"] = np.load(
server_args.debug_tensor_dump_input_file
).tolist()
json_data["sampling_params"]["max_new_tokens"] = 0
try:
if server_args.disaggregation_mode == "null":
res = requests.post(
url + request_name,
json=json_data,
headers=headers,
timeout=600,
)
assert res.status_code == 200, f"{res}"
else:
logger.info(f"Start of prefill warmup ...")
json_data = {
"sampling_params": {
"temperature": 0.0,
"max_new_tokens": 8,
"ignore_eos": True,
},
"bootstrap_host": [FAKE_BOOTSTRAP_HOST] * server_args.dp_size,
# This is a hack to ensure fake transfer is enabled during prefill warmup
# ensure each dp rank has a unique bootstrap_room during prefill warmup
"bootstrap_room": [
i * (2**63 // server_args.dp_size) + (i % server_args.tp_size)
for i in range(server_args.dp_size)
],
"input_ids": [[0, 1, 2, 3]] * server_args.dp_size,
}
res = requests.post(
url + request_name,
json=json_data,
headers=headers,
timeout=1800,  # DeepGEMM precaching can take a very long time when the cache is cold.
)
logger.info(
f"End of prefill warmup with status {res.status_code}, resp: {res.json()}"
)
except Exception:
last_traceback = get_exception_traceback()
if pipe_finish_writer is not None:
pipe_finish_writer.send(last_traceback)
logger.error(f"Initialization failed. warmup error: {last_traceback}")
kill_process_tree(os.getpid())
return
# Debug print
# logger.info(f"{res.json()=}")
logger.info("The server is fired up and ready to roll!")
if pipe_finish_writer is not None:
pipe_finish_writer.send("ready")
if server_args.delete_ckpt_after_loading:
delete_directory(server_args.model_path)
if server_args.debug_tensor_dump_input_file:
kill_process_tree(os.getpid())
if server_args.pdlb_url is not None:
register_disaggregation_server(
server_args.disaggregation_mode,
server_args.port,
server_args.disaggregation_bootstrap_port,
server_args.pdlb_url,
)
if launch_callback is not None:
launch_callback()
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="SGLang OpenAI-Compatible API Server")
# Add arguments from ServerArgs. This allows reuse of existing CLI definitions.
ServerArgs.add_cli_args(parser)
# Potentially add server-specific arguments here in the future if needed
args = parser.parse_args()
server_args = ServerArgs.from_cli_args(args)
# Store server_args in app.state for access in lifespan and endpoints
app.state.server_args = server_args
# Configure logging
logging.basicConfig(
level=server_args.log_level.upper(),
format="%(asctime)s - %(name)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s",
)
# Send a warmup request - we create the thread here and launch it
# in the lifespan after all other warmups have fired.
warmup_thread = threading.Thread(
target=_wait_and_warmup,
args=(
server_args,
None,
None, # Never used
None,
),
)
app.warmup_thread = warmup_thread
try:
# Start the server
set_uvicorn_logging_configs()
uvicorn.run(
app,
host=server_args.host,
port=server_args.port,
log_level=server_args.log_level.lower(),
timeout_keep_alive=60, # Increased keep-alive for potentially long requests
loop="uvloop", # Use uvloop for better performance if available
)
finally:
warmup_thread.join()
...@@ -207,7 +207,7 @@ class CompletionResponseChoice(BaseModel):
     index: int
     text: str
     logprobs: Optional[LogProbs] = None
-    finish_reason: Literal["stop", "length", "content_filter", "abort"]
+    finish_reason: Optional[Literal["stop", "length", "content_filter", "abort"]] = None
     matched_stop: Union[None, int, str] = None
     hidden_states: Optional[object] = None
...@@ -404,7 +404,6 @@ class ChatCompletionRequest(BaseModel): ...@@ -404,7 +404,6 @@ class ChatCompletionRequest(BaseModel):
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def set_tool_choice_default(cls, values): def set_tool_choice_default(cls, values):
if isinstance(values, dict):
if values.get("tool_choice") is None: if values.get("tool_choice") is None:
if values.get("tools") is None: if values.get("tools") is None:
values["tool_choice"] = "none" values["tool_choice"] = "none"
...@@ -412,13 +411,6 @@ class ChatCompletionRequest(BaseModel): ...@@ -412,13 +411,6 @@ class ChatCompletionRequest(BaseModel):
values["tool_choice"] = "auto" values["tool_choice"] = "auto"
return values return values
@field_validator("messages")
@classmethod
def validate_messages_not_empty(cls, v):
if not v:
raise ValueError("Messages cannot be empty")
return v
# Extra parameters for SRT backend only and will be ignored by OpenAI models. # Extra parameters for SRT backend only and will be ignored by OpenAI models.
top_k: int = -1 top_k: int = -1
min_p: float = 0.0 min_p: float = 0.0
...@@ -457,9 +449,11 @@ class ChatCompletionResponseChoice(BaseModel):
     index: int
     message: ChatMessage
     logprobs: Optional[Union[LogProbs, ChoiceLogprobs]] = None
-    finish_reason: Literal[
-        "stop", "length", "tool_calls", "content_filter", "function_call", "abort"
-    ]
+    finish_reason: Optional[
+        Literal[
+            "stop", "length", "tool_calls", "content_filter", "function_call", "abort"
+        ]
+    ] = None
     matched_stop: Union[None, int, str] = None
     hidden_states: Optional[object] = None
...@@ -530,7 +524,7 @@ class EmbeddingRequest(BaseModel):
     input: EmbeddingInput
     model: str
     encoding_format: str = "float"
-    dimensions: int = None
+    dimensions: Optional[int] = None
     user: Optional[str] = None

     # The request id.
......
...@@ -2,16 +2,12 @@ import json
 import logging
 import uuid
 from abc import ABC, abstractmethod
-from typing import Any, Dict, Optional, Union
+from typing import Any, Optional, Union

 from fastapi import Request
 from fastapi.responses import ORJSONResponse, StreamingResponse

-from sglang.srt.entrypoints.openai.protocol import (
-    ErrorResponse,
-    OpenAIServingRequest,
-    UsageInfo,
-)
+from sglang.srt.entrypoints.openai.protocol import ErrorResponse, OpenAIServingRequest
 from sglang.srt.managers.io_struct import GenerateReqInput
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
...@@ -51,7 +47,7 @@ class OpenAIServingBase(ABC): ...@@ -51,7 +47,7 @@ class OpenAIServingBase(ABC):
) )
except Exception as e: except Exception as e:
logger.error(f"Error in request: {e}") logger.exception(f"Error in request: {e}")
return self.create_error_response( return self.create_error_response(
message=f"Internal server error: {str(e)}", message=f"Internal server error: {str(e)}",
err_type="InternalServerError", err_type="InternalServerError",
...@@ -63,8 +59,12 @@ class OpenAIServingBase(ABC): ...@@ -63,8 +59,12 @@ class OpenAIServingBase(ABC):
"""Generate request ID based on request type""" """Generate request ID based on request type"""
pass pass
-    def _generate_request_id_base(self, request: OpenAIServingRequest) -> str:
+    def _generate_request_id_base(self, request: OpenAIServingRequest) -> Optional[str]:
         """Generate request ID based on request type"""
+        return None
+
+        # TODO(chang): the rid is used in the io_struct check and often trips the
+        # "The rid should be a list" assertion.
+        # Temporarily return None in this function until the rid logic is clear.

         if rid := getattr(request, "rid", None):
             return rid
...@@ -83,7 +83,7 @@ class OpenAIServingBase(ABC): ...@@ -83,7 +83,7 @@ class OpenAIServingBase(ABC):
adapted_request: GenerateReqInput, adapted_request: GenerateReqInput,
request: OpenAIServingRequest, request: OpenAIServingRequest,
raw_request: Request, raw_request: Request,
) -> StreamingResponse: ) -> Union[StreamingResponse, ErrorResponse, ORJSONResponse]:
"""Handle streaming request """Handle streaming request
Override this method in child classes that support streaming requests. Override this method in child classes that support streaming requests.
...@@ -99,7 +99,7 @@ class OpenAIServingBase(ABC): ...@@ -99,7 +99,7 @@ class OpenAIServingBase(ABC):
adapted_request: GenerateReqInput, adapted_request: GenerateReqInput,
request: OpenAIServingRequest, request: OpenAIServingRequest,
raw_request: Request, raw_request: Request,
) -> Union[Any, ErrorResponse]: ) -> Union[Any, ErrorResponse, ORJSONResponse]:
"""Handle non-streaming request """Handle non-streaming request
Override this method in child classes that support non-streaming requests. Override this method in child classes that support non-streaming requests.
...@@ -110,7 +110,7 @@ class OpenAIServingBase(ABC): ...@@ -110,7 +110,7 @@ class OpenAIServingBase(ABC):
status_code=501, status_code=501,
) )
def _validate_request(self, request: OpenAIServingRequest) -> Optional[str]: def _validate_request(self, _: OpenAIServingRequest) -> Optional[str]:
"""Validate request""" """Validate request"""
pass pass
...@@ -122,6 +122,7 @@ class OpenAIServingBase(ABC): ...@@ -122,6 +122,7 @@ class OpenAIServingBase(ABC):
param: Optional[str] = None, param: Optional[str] = None,
) -> ORJSONResponse: ) -> ORJSONResponse:
"""Create an error response""" """Create an error response"""
# TODO: remove fastapi dependency in openai and move response handling to the entrypoint
error = ErrorResponse( error = ErrorResponse(
object="error", object="error",
message=message, message=message,
......
import base64
import json import json
import logging import logging
import time import time
...@@ -6,7 +5,7 @@ import uuid ...@@ -6,7 +5,7 @@ import uuid
from typing import Any, AsyncGenerator, Dict, List, Optional, Union from typing import Any, AsyncGenerator, Dict, List, Optional, Union
from fastapi import Request from fastapi import Request
from fastapi.responses import StreamingResponse from fastapi.responses import ORJSONResponse, StreamingResponse
from sglang.srt.conversation import generate_chat_conv from sglang.srt.conversation import generate_chat_conv
from sglang.srt.entrypoints.openai.protocol import ( from sglang.srt.entrypoints.openai.protocol import (
...@@ -28,13 +27,14 @@ from sglang.srt.entrypoints.openai.protocol import ( ...@@ -28,13 +27,14 @@ from sglang.srt.entrypoints.openai.protocol import (
from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
from sglang.srt.entrypoints.openai.usage_processor import UsageProcessor from sglang.srt.entrypoints.openai.usage_processor import UsageProcessor
from sglang.srt.entrypoints.openai.utils import ( from sglang.srt.entrypoints.openai.utils import (
detect_template_content_format,
process_content_for_template_format,
process_hidden_states_from_ret, process_hidden_states_from_ret,
to_openai_style_logprobs, to_openai_style_logprobs,
) )
from sglang.srt.function_call.function_call_parser import FunctionCallParser from sglang.srt.function_call.function_call_parser import FunctionCallParser
from sglang.srt.jinja_template_utils import process_content_for_template_format
from sglang.srt.managers.io_struct import GenerateReqInput from sglang.srt.managers.io_struct import GenerateReqInput
from sglang.srt.managers.template_manager import TemplateManager
from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.reasoning_parser import ReasoningParser from sglang.srt.reasoning_parser import ReasoningParser
from sglang.utils import convert_json_schema_to_str from sglang.utils import convert_json_schema_to_str
...@@ -42,13 +42,13 @@ logger = logging.getLogger(__name__)
 class OpenAIServingChat(OpenAIServingBase):
-    """Handler for chat completion requests"""
+    """Handler for /v1/chat/completions requests"""

-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        # Instance-specific cache for template content format detection
-        self._cached_chat_template = None
-        self._cached_template_format = None
+    def __init__(
+        self, tokenizer_manager: TokenizerManager, template_manager: TemplateManager
+    ):
+        super().__init__(tokenizer_manager)
+        self.template_manager = template_manager

     def _request_id_prefix(self) -> str:
         return "chatcmpl-"
...@@ -142,19 +142,14 @@ class OpenAIServingChat(OpenAIServingBase): ...@@ -142,19 +142,14 @@ class OpenAIServingChat(OpenAIServingBase):
) )
# Use chat template # Use chat template
if ( if self.template_manager.chat_template_name is None:
hasattr(self.tokenizer_manager, "chat_template_name")
and self.tokenizer_manager.chat_template_name is None
):
prompt, prompt_ids, image_data, audio_data, modalities, stop = ( prompt, prompt_ids, image_data, audio_data, modalities, stop = (
self._apply_jinja_template(request, tools, is_multimodal) self._apply_jinja_template(request, tools, is_multimodal)
) )
else: else:
prompt, image_data, audio_data, modalities, stop = ( prompt, prompt_ids, image_data, audio_data, modalities, stop = (
self._apply_conversation_template(request) self._apply_conversation_template(request, is_multimodal)
) )
if not is_multimodal:
prompt_ids = self.tokenizer_manager.tokenizer.encode(prompt)
else: else:
# Use raw prompt # Use raw prompt
prompt_ids = request.messages prompt_ids = request.messages
...@@ -181,23 +176,14 @@ class OpenAIServingChat(OpenAIServingBase): ...@@ -181,23 +176,14 @@ class OpenAIServingChat(OpenAIServingBase):
is_multimodal: bool, is_multimodal: bool,
) -> tuple[str, List[int], Optional[Any], Optional[Any], List[str], List[str]]: ) -> tuple[str, List[int], Optional[Any], Optional[Any], List[str], List[str]]:
"""Apply Jinja chat template""" """Apply Jinja chat template"""
prompt = ""
prompt_ids = []
openai_compatible_messages = [] openai_compatible_messages = []
image_data = [] image_data = []
audio_data = [] audio_data = []
modalities = [] modalities = []
# Detect template content format template_content_format = self.template_manager.jinja_template_content_format
current_template = self.tokenizer_manager.tokenizer.chat_template
if current_template != self._cached_chat_template:
self._cached_chat_template = current_template
self._cached_template_format = detect_template_content_format(
current_template
)
logger.info(
f"Detected chat template content format: {self._cached_template_format}"
)
template_content_format = self._cached_template_format
for message in request.messages: for message in request.messages:
if message.content is None: if message.content is None:
...@@ -262,14 +248,21 @@ class OpenAIServingChat(OpenAIServingBase): ...@@ -262,14 +248,21 @@ class OpenAIServingChat(OpenAIServingBase):
if is_multimodal: if is_multimodal:
prompt = self.tokenizer_manager.tokenizer.decode(prompt_ids) prompt = self.tokenizer_manager.tokenizer.decode(prompt_ids)
stop = request.stop or [] stop = request.stop
image_data = image_data if image_data else None
audio_data = audio_data if audio_data else None
modalities = modalities if modalities else []
return prompt, prompt_ids, image_data, audio_data, modalities, stop return prompt, prompt_ids, image_data, audio_data, modalities, stop
def _apply_conversation_template( def _apply_conversation_template(
self, request: ChatCompletionRequest self,
) -> tuple[str, Optional[Any], Optional[Any], List[str], List[str]]: request: ChatCompletionRequest,
is_multimodal: bool,
) -> tuple[str, Optional[Any], Optional[Any], List[str], List[str], List[str]]:
"""Apply conversation template""" """Apply conversation template"""
conv = generate_chat_conv(request, self.tokenizer_manager.chat_template_name) prompt = ""
prompt_ids = []
conv = generate_chat_conv(request, self.template_manager.chat_template_name)
# If we should continue the final assistant message, adjust the conversation. # If we should continue the final assistant message, adjust the conversation.
if ( if (
...@@ -296,9 +289,9 @@ class OpenAIServingChat(OpenAIServingBase): ...@@ -296,9 +289,9 @@ class OpenAIServingChat(OpenAIServingBase):
else: else:
prompt = conv.get_prompt() prompt = conv.get_prompt()
image_data = conv.image_data image_data = conv.image_data if conv.image_data else None
audio_data = conv.audio_data audio_data = conv.audio_data if conv.audio_data else None
modalities = conv.modalities modalities = conv.modalities if conv.modalities else []
stop = conv.stop_str or [] if not request.ignore_eos else [] stop = conv.stop_str or [] if not request.ignore_eos else []
if request.stop: if request.stop:
...@@ -307,7 +300,10 @@ class OpenAIServingChat(OpenAIServingBase): ...@@ -307,7 +300,10 @@ class OpenAIServingChat(OpenAIServingBase):
else: else:
stop.extend(request.stop) stop.extend(request.stop)
return prompt, image_data, audio_data, modalities, stop if not is_multimodal:
prompt_ids = self.tokenizer_manager.tokenizer.encode(prompt)
return prompt, prompt_ids, image_data, audio_data, modalities, stop
def _build_sampling_params( def _build_sampling_params(
self, self,
...@@ -459,13 +455,9 @@ class OpenAIServingChat(OpenAIServingBase): ...@@ -459,13 +455,9 @@ class OpenAIServingChat(OpenAIServingBase):
stream_buffers[index] = stream_buffer + delta stream_buffers[index] = stream_buffer + delta
# Handle reasoning content # Handle reasoning content
enable_thinking = getattr(request, "chat_template_kwargs", {}).get(
"enable_thinking", True
)
if ( if (
self.tokenizer_manager.server_args.reasoning_parser self.tokenizer_manager.server_args.reasoning_parser
and request.separate_reasoning and request.separate_reasoning
and enable_thinking
): ):
reasoning_text, delta = self._process_reasoning_stream( reasoning_text, delta = self._process_reasoning_stream(
index, delta, reasoning_parser_dict, content, request index, delta, reasoning_parser_dict, content, request
...@@ -591,7 +583,7 @@ class OpenAIServingChat(OpenAIServingBase): ...@@ -591,7 +583,7 @@ class OpenAIServingChat(OpenAIServingBase):
) )
yield f"data: {usage_chunk.model_dump_json()}\n\n" yield f"data: {usage_chunk.model_dump_json()}\n\n"
except Exception as e: except ValueError as e:
error = self.create_streaming_error_response(str(e)) error = self.create_streaming_error_response(str(e))
yield f"data: {error}\n\n" yield f"data: {error}\n\n"
...@@ -602,7 +594,7 @@ class OpenAIServingChat(OpenAIServingBase): ...@@ -602,7 +594,7 @@ class OpenAIServingChat(OpenAIServingBase):
adapted_request: GenerateReqInput, adapted_request: GenerateReqInput,
request: ChatCompletionRequest, request: ChatCompletionRequest,
raw_request: Request, raw_request: Request,
) -> Union[ChatCompletionResponse, ErrorResponse]: ) -> Union[ChatCompletionResponse, ErrorResponse, ORJSONResponse]:
"""Handle non-streaming chat completion request""" """Handle non-streaming chat completion request"""
try: try:
ret = await self.tokenizer_manager.generate_request( ret = await self.tokenizer_manager.generate_request(
...@@ -627,7 +619,7 @@ class OpenAIServingChat(OpenAIServingBase): ...@@ -627,7 +619,7 @@ class OpenAIServingChat(OpenAIServingBase):
request: ChatCompletionRequest, request: ChatCompletionRequest,
ret: List[Dict[str, Any]], ret: List[Dict[str, Any]],
created: int, created: int,
) -> ChatCompletionResponse: ) -> Union[ChatCompletionResponse, ORJSONResponse]:
"""Build chat completion response from generation results""" """Build chat completion response from generation results"""
choices = [] choices = []
...@@ -645,11 +637,8 @@ class OpenAIServingChat(OpenAIServingBase): ...@@ -645,11 +637,8 @@ class OpenAIServingChat(OpenAIServingBase):
# Handle reasoning content # Handle reasoning content
reasoning_text = None reasoning_text = None
enable_thinking = getattr(request, "chat_template_kwargs", {}).get(
"enable_thinking", True
)
reasoning_parser = self.tokenizer_manager.server_args.reasoning_parser reasoning_parser = self.tokenizer_manager.server_args.reasoning_parser
if reasoning_parser and request.separate_reasoning and enable_thinking: if reasoning_parser and request.separate_reasoning:
try: try:
parser = ReasoningParser( parser = ReasoningParser(
model_type=reasoning_parser, stream_reasoning=False model_type=reasoning_parser, stream_reasoning=False
...@@ -691,9 +680,10 @@ class OpenAIServingChat(OpenAIServingBase): ...@@ -691,9 +680,10 @@ class OpenAIServingChat(OpenAIServingBase):
choices.append(choice_data) choices.append(choice_data)
# Calculate usage # Calculate usage
cache_report = self.tokenizer_manager.server_args.enable_cache_report
usage = UsageProcessor.calculate_response_usage( usage = UsageProcessor.calculate_response_usage(
ret, n_choices=request.n, enable_cache_report=cache_report ret,
n_choices=request.n,
enable_cache_report=self.tokenizer_manager.server_args.enable_cache_report,
) )
return ChatCompletionResponse( return ChatCompletionResponse(
...@@ -821,6 +811,25 @@ class OpenAIServingChat(OpenAIServingBase): ...@@ -821,6 +811,25 @@ class OpenAIServingChat(OpenAIServingBase):
reasoning_parser = reasoning_parser_dict[index] reasoning_parser = reasoning_parser_dict[index]
return reasoning_parser.parse_stream_chunk(delta) return reasoning_parser.parse_stream_chunk(delta)
def _get_enable_thinking_from_request(request: ChatCompletionRequest) -> bool:
"""Extracts the 'enable_thinking' flag from request chat_template_kwargs.
NOTE: This parameter is only useful for models that support enable_thinking
flag, such as Qwen3.
Args:
    request: The request object (or an item from a list of requests).
Returns:
    The boolean value of 'enable_thinking' if present; otherwise True.
"""
if (
hasattr(request, "chat_template_kwargs")
and request.chat_template_kwargs
and request.chat_template_kwargs.get("enable_thinking") is not None
):
return request.chat_template_kwargs.get("enable_thinking")
return True
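A minimal sketch of how the enable_thinking lookup above resolves for a few request shapes; SimpleNamespace stands in for ChatCompletionRequest here, and the helper is assumed to be in scope:

```python
# Sketch only: SimpleNamespace stands in for ChatCompletionRequest.
from types import SimpleNamespace

# No chat_template_kwargs -> defaults to True
print(_get_enable_thinking_from_request(SimpleNamespace(chat_template_kwargs=None)))  # True

# Explicitly disabled (e.g. a Qwen3-style request) -> False
req = SimpleNamespace(chat_template_kwargs={"enable_thinking": False})
print(_get_enable_thinking_from_request(req))  # False

# Other kwargs present but no flag -> still True
req = SimpleNamespace(chat_template_kwargs={"foo": "bar"})
print(_get_enable_thinking_from_request(req))  # True
```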
async def _process_tool_call_stream( async def _process_tool_call_stream(
self, self,
index: int, index: int,
......
...@@ -3,12 +3,9 @@ import time ...@@ -3,12 +3,9 @@ import time
from typing import Any, AsyncGenerator, Dict, List, Union from typing import Any, AsyncGenerator, Dict, List, Union
from fastapi import Request from fastapi import Request
from fastapi.responses import StreamingResponse from fastapi.responses import ORJSONResponse, StreamingResponse
from sglang.srt.code_completion_parser import ( from sglang.srt.code_completion_parser import generate_completion_prompt_from_request
generate_completion_prompt_from_request,
is_completion_template_defined,
)
from sglang.srt.entrypoints.openai.protocol import ( from sglang.srt.entrypoints.openai.protocol import (
CompletionRequest, CompletionRequest,
CompletionResponse, CompletionResponse,
...@@ -24,12 +21,22 @@ from sglang.srt.entrypoints.openai.utils import ( ...@@ -24,12 +21,22 @@ from sglang.srt.entrypoints.openai.utils import (
to_openai_style_logprobs, to_openai_style_logprobs,
) )
from sglang.srt.managers.io_struct import GenerateReqInput from sglang.srt.managers.io_struct import GenerateReqInput
from sglang.srt.managers.template_manager import TemplateManager
from sglang.srt.managers.tokenizer_manager import TokenizerManager
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class OpenAIServingCompletion(OpenAIServingBase): class OpenAIServingCompletion(OpenAIServingBase):
"""Handler for completion requests""" """Handler for /v1/completion requests"""
def __init__(
self,
tokenizer_manager: TokenizerManager,
template_manager: TemplateManager,
):
super().__init__(tokenizer_manager)
self.template_manager = template_manager
def _request_id_prefix(self) -> str: def _request_id_prefix(self) -> str:
return "cmpl-" return "cmpl-"
...@@ -47,7 +54,7 @@ class OpenAIServingCompletion(OpenAIServingBase): ...@@ -47,7 +54,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
) )
# Process prompt # Process prompt
prompt = request.prompt prompt = request.prompt
if is_completion_template_defined(): if self.template_manager.completion_template_name is not None:
prompt = generate_completion_prompt_from_request(request) prompt = generate_completion_prompt_from_request(request)
# Set logprob start length based on echo and logprobs # Set logprob start length based on echo and logprobs
...@@ -141,6 +148,7 @@ class OpenAIServingCompletion(OpenAIServingBase): ...@@ -141,6 +148,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
prompt_tokens = {} prompt_tokens = {}
completion_tokens = {} completion_tokens = {}
cached_tokens = {} cached_tokens = {}
hidden_states = {}
try: try:
async for content in self.tokenizer_manager.generate_request( async for content in self.tokenizer_manager.generate_request(
...@@ -152,6 +160,7 @@ class OpenAIServingCompletion(OpenAIServingBase): ...@@ -152,6 +160,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
prompt_tokens[index] = content["meta_info"]["prompt_tokens"] prompt_tokens[index] = content["meta_info"]["prompt_tokens"]
completion_tokens[index] = content["meta_info"]["completion_tokens"] completion_tokens[index] = content["meta_info"]["completion_tokens"]
cached_tokens[index] = content["meta_info"].get("cached_tokens", 0) cached_tokens[index] = content["meta_info"].get("cached_tokens", 0)
hidden_states[index] = content["meta_info"].get("hidden_states", None)
stream_buffer = stream_buffers.get(index, "") stream_buffer = stream_buffers.get(index, "")
# Handle echo for first chunk # Handle echo for first chunk
...@@ -192,7 +201,6 @@ class OpenAIServingCompletion(OpenAIServingBase): ...@@ -192,7 +201,6 @@ class OpenAIServingCompletion(OpenAIServingBase):
delta = text[len(stream_buffer) :] delta = text[len(stream_buffer) :]
stream_buffers[index] = stream_buffer + delta stream_buffers[index] = stream_buffer + delta
finish_reason = content["meta_info"]["finish_reason"] finish_reason = content["meta_info"]["finish_reason"]
hidden_states = content["meta_info"].get("hidden_states", None)
choice_data = CompletionResponseStreamChoice( choice_data = CompletionResponseStreamChoice(
index=index, index=index,
...@@ -269,7 +277,7 @@ class OpenAIServingCompletion(OpenAIServingBase): ...@@ -269,7 +277,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
adapted_request: GenerateReqInput, adapted_request: GenerateReqInput,
request: CompletionRequest, request: CompletionRequest,
raw_request: Request, raw_request: Request,
) -> Union[CompletionResponse, ErrorResponse]: ) -> Union[CompletionResponse, ErrorResponse, ORJSONResponse]:
"""Handle non-streaming completion request""" """Handle non-streaming completion request"""
try: try:
generator = self.tokenizer_manager.generate_request( generator = self.tokenizer_manager.generate_request(
......
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Union
from fastapi import Request from fastapi import Request
from fastapi.responses import ORJSONResponse
from sglang.srt.conversation import generate_embedding_convs from sglang.srt.conversation import generate_embedding_convs
from sglang.srt.entrypoints.openai.protocol import ( from sglang.srt.entrypoints.openai.protocol import (
...@@ -13,10 +14,20 @@ from sglang.srt.entrypoints.openai.protocol import ( ...@@ -13,10 +14,20 @@ from sglang.srt.entrypoints.openai.protocol import (
) )
from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
from sglang.srt.managers.io_struct import EmbeddingReqInput from sglang.srt.managers.io_struct import EmbeddingReqInput
from sglang.srt.managers.template_manager import TemplateManager
from sglang.srt.managers.tokenizer_manager import TokenizerManager
class OpenAIServingEmbedding(OpenAIServingBase): class OpenAIServingEmbedding(OpenAIServingBase):
"""Handler for embedding requests""" """Handler for v1/embeddings requests"""
def __init__(
self,
tokenizer_manager: TokenizerManager,
template_manager: TemplateManager,
):
super().__init__(tokenizer_manager)
self.template_manager = template_manager
def _request_id_prefix(self) -> str: def _request_id_prefix(self) -> str:
return "embd-" return "embd-"
...@@ -68,10 +79,6 @@ class OpenAIServingEmbedding(OpenAIServingBase): ...@@ -68,10 +79,6 @@ class OpenAIServingEmbedding(OpenAIServingBase):
prompt_kwargs = {"text": prompt} prompt_kwargs = {"text": prompt}
elif isinstance(prompt, list): elif isinstance(prompt, list):
if len(prompt) > 0 and isinstance(prompt[0], str): if len(prompt) > 0 and isinstance(prompt[0], str):
# List of strings - if it's a single string in a list, treat as single string
if len(prompt) == 1:
prompt_kwargs = {"text": prompt[0]}
else:
prompt_kwargs = {"text": prompt} prompt_kwargs = {"text": prompt}
elif len(prompt) > 0 and isinstance(prompt[0], MultimodalEmbeddingInput): elif len(prompt) > 0 and isinstance(prompt[0], MultimodalEmbeddingInput):
# Handle multimodal embedding inputs # Handle multimodal embedding inputs
...@@ -84,11 +91,10 @@ class OpenAIServingEmbedding(OpenAIServingBase): ...@@ -84,11 +91,10 @@ class OpenAIServingEmbedding(OpenAIServingBase):
generate_prompts = [] generate_prompts = []
# Check if we have a chat template for multimodal embeddings # Check if we have a chat template for multimodal embeddings
chat_template_name = getattr( if self.template_manager.chat_template_name is not None:
self.tokenizer_manager, "chat_template_name", None convs = generate_embedding_convs(
texts, images, self.template_manager.chat_template_name
) )
if chat_template_name is not None:
convs = generate_embedding_convs(texts, images, chat_template_name)
for conv in convs: for conv in convs:
generate_prompts.append(conv.get_prompt()) generate_prompts.append(conv.get_prompt())
else: else:
...@@ -122,7 +128,7 @@ class OpenAIServingEmbedding(OpenAIServingBase): ...@@ -122,7 +128,7 @@ class OpenAIServingEmbedding(OpenAIServingBase):
adapted_request: EmbeddingReqInput, adapted_request: EmbeddingReqInput,
request: EmbeddingRequest, request: EmbeddingRequest,
raw_request: Request, raw_request: Request,
) -> Union[EmbeddingResponse, ErrorResponse]: ) -> Union[EmbeddingResponse, ErrorResponse, ORJSONResponse]:
"""Handle the embedding request""" """Handle the embedding request"""
try: try:
ret = await self.tokenizer_manager.generate_request( ret = await self.tokenizer_manager.generate_request(
......
...@@ -2,6 +2,7 @@ import logging ...@@ -2,6 +2,7 @@ import logging
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Union
from fastapi import Request from fastapi import Request
from fastapi.responses import ORJSONResponse
from sglang.srt.entrypoints.openai.protocol import ( from sglang.srt.entrypoints.openai.protocol import (
ErrorResponse, ErrorResponse,
...@@ -15,7 +16,10 @@ logger = logging.getLogger(__name__) ...@@ -15,7 +16,10 @@ logger = logging.getLogger(__name__)
class OpenAIServingRerank(OpenAIServingBase): class OpenAIServingRerank(OpenAIServingBase):
"""Handler for rerank requests""" """Handler for /v1/rerank requests"""
# NOTE: /v1/rerank is not an official OpenAI endpoint. This module may be moved
# to another module in the future.
def _request_id_prefix(self) -> str: def _request_id_prefix(self) -> str:
return "rerank-" return "rerank-"
...@@ -61,7 +65,7 @@ class OpenAIServingRerank(OpenAIServingBase): ...@@ -61,7 +65,7 @@ class OpenAIServingRerank(OpenAIServingBase):
adapted_request: EmbeddingReqInput, adapted_request: EmbeddingReqInput,
request: V1RerankReqInput, request: V1RerankReqInput,
raw_request: Request, raw_request: Request,
) -> Union[RerankResponse, ErrorResponse]: ) -> Union[List[RerankResponse], ErrorResponse, ORJSONResponse]:
"""Handle the rerank request""" """Handle the rerank request"""
try: try:
ret = await self.tokenizer_manager.generate_request( ret = await self.tokenizer_manager.generate_request(
...@@ -74,16 +78,16 @@ class OpenAIServingRerank(OpenAIServingBase): ...@@ -74,16 +78,16 @@ class OpenAIServingRerank(OpenAIServingBase):
if not isinstance(ret, list): if not isinstance(ret, list):
ret = [ret] ret = [ret]
response = self._build_rerank_response(ret, request) responses = self._build_rerank_response(ret, request)
return response return responses
def _build_rerank_response( def _build_rerank_response(
self, ret: List[Dict[str, Any]], request: V1RerankReqInput self, ret: List[Dict[str, Any]], request: V1RerankReqInput
) -> List[RerankResponse]: ) -> List[RerankResponse]:
"""Build the rerank response from generation results""" """Build the rerank response from generation results"""
response = [] responses = []
for idx, ret_item in enumerate(ret): for idx, ret_item in enumerate(ret):
response.append( responses.append(
RerankResponse( RerankResponse(
score=ret_item["embedding"], score=ret_item["embedding"],
document=request.documents[idx], document=request.documents[idx],
...@@ -93,6 +97,6 @@ class OpenAIServingRerank(OpenAIServingBase): ...@@ -93,6 +97,6 @@ class OpenAIServingRerank(OpenAIServingBase):
) )
# Sort by score in descending order (highest relevance first) # Sort by score in descending order (highest relevance first)
response.sort(key=lambda x: x.score, reverse=True) responses.sort(key=lambda x: x.score, reverse=True)
return response return responses
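A small stand-in sketch of the ordering behavior above; FakeRerankResponse is not the real protocol class, and the scores mimic the values pulled from ret_item["embedding"]:

```python
# Stand-in dataclass, illustrative only.
from dataclasses import dataclass

@dataclass
class FakeRerankResponse:
    score: float
    document: str

ret = [{"embedding": 0.12}, {"embedding": 0.87}, {"embedding": 0.45}]
documents = ["doc-a", "doc-b", "doc-c"]

responses = [
    FakeRerankResponse(score=item["embedding"], document=documents[i])
    for i, item in enumerate(ret)
]
responses.sort(key=lambda x: x.score, reverse=True)
print([r.document for r in responses])  # ['doc-b', 'doc-c', 'doc-a']
```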
import logging import logging
from typing import Any, Dict, List, Optional, Union from typing import Union
from fastapi import Request from fastapi import Request
...@@ -14,7 +14,10 @@ logger = logging.getLogger(__name__) ...@@ -14,7 +14,10 @@ logger = logging.getLogger(__name__)
class OpenAIServingScore(OpenAIServingBase): class OpenAIServingScore(OpenAIServingBase):
"""Handler for scoring requests""" """Handler for /v1/score requests"""
# NOTE: /v1/score is not an official OpenAI endpoint. This module may be moved
# to another module in the future.
def _request_id_prefix(self) -> str: def _request_id_prefix(self) -> str:
return "score-" return "score-"
......
import logging import logging
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Union
import jinja2.nodes
import transformers.utils.chat_template_utils as hf_chat_utils
from sglang.srt.entrypoints.openai.protocol import ( from sglang.srt.entrypoints.openai.protocol import (
ChatCompletionRequest, ChatCompletionRequest,
CompletionRequest, CompletionRequest,
...@@ -13,168 +10,6 @@ from sglang.srt.entrypoints.openai.protocol import ( ...@@ -13,168 +10,6 @@ from sglang.srt.entrypoints.openai.protocol import (
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# ============================================================================
# JINJA TEMPLATE CONTENT FORMAT DETECTION
# ============================================================================
#
# This adapts vLLM's approach for detecting chat template content format:
# https://github.com/vllm-project/vllm/blob/02f0c7b220422792f5e53de2a7d51d2d3ff2df28/vllm/entrypoints/chat_utils.py#L296-L313
# - Analyzes Jinja template AST to detect content iteration patterns
# - 'openai' format: templates with {%- for content in message['content'] -%} loops
# - 'string' format: templates that expect simple string content
# - Processes content accordingly to match template expectations
def _is_var_access(node: jinja2.nodes.Node, varname: str) -> bool:
"""Check if node is a variable access like {{ varname }}"""
if isinstance(node, jinja2.nodes.Name):
return node.ctx == "load" and node.name == varname
return False
def _is_attr_access(node: jinja2.nodes.Node, varname: str, key: str) -> bool:
"""Check if node is an attribute access like {{ varname['key'] }} or {{ varname.key }}"""
if isinstance(node, jinja2.nodes.Getitem):
return (
_is_var_access(node.node, varname)
and isinstance(node.arg, jinja2.nodes.Const)
and node.arg.value == key
)
if isinstance(node, jinja2.nodes.Getattr):
return _is_var_access(node.node, varname) and node.attr == key
return False
def _is_var_or_elems_access(
node: jinja2.nodes.Node,
varname: str,
key: str = None,
) -> bool:
"""Check if node accesses varname or varname[key] with filters/tests"""
if isinstance(node, jinja2.nodes.Filter):
return node.node is not None and _is_var_or_elems_access(
node.node, varname, key
)
if isinstance(node, jinja2.nodes.Test):
return _is_var_or_elems_access(node.node, varname, key)
if isinstance(node, jinja2.nodes.Getitem) and isinstance(
node.arg, jinja2.nodes.Slice
):
return _is_var_or_elems_access(node.node, varname, key)
return _is_attr_access(node, varname, key) if key else _is_var_access(node, varname)
def _try_extract_ast(chat_template: str):
"""Try to parse the Jinja template into an AST"""
try:
jinja_compiled = hf_chat_utils._compile_jinja_template(chat_template)
return jinja_compiled.environment.parse(chat_template)
except Exception as e:
logger.debug(f"Error when compiling Jinja template: {e}")
return None
def detect_template_content_format(chat_template: str) -> str:
"""
Detect whether a chat template expects 'string' or 'openai' content format.
- 'string': content is a simple string (like DeepSeek templates)
- 'openai': content is a list of structured dicts (like Llama4 templates)
Detection logic:
- If template has loops like {%- for content in message['content'] -%} → 'openai'
- Otherwise → 'string'
"""
jinja_ast = _try_extract_ast(chat_template)
if jinja_ast is None:
return "string"
try:
# Look for patterns like: {%- for content in message['content'] -%}
for loop_ast in jinja_ast.find_all(jinja2.nodes.For):
loop_iter = loop_ast.iter
# Check if iterating over message['content'] or similar
if _is_var_or_elems_access(loop_iter, "message", "content"):
return "openai" # Found content iteration → openai format
return "string" # No content loops found → string format
except Exception as e:
logger.debug(f"Error when parsing AST of Jinja template: {e}")
return "string"
def process_content_for_template_format(
msg_dict: dict,
content_format: str,
image_data: list,
audio_data: list,
modalities: list,
) -> dict:
"""
Process message content based on detected template format.
Args:
msg_dict: Message dictionary with content
content_format: 'string' or 'openai' (detected via AST analysis)
image_data: List to append extracted image URLs
audio_data: List to append extracted audio URLs
modalities: List to append modalities
Returns:
Processed message dictionary
"""
if not isinstance(msg_dict.get("content"), list):
# Already a string or None, no processing needed
return {k: v for k, v in msg_dict.items() if v is not None}
if content_format == "openai":
# OpenAI format: preserve structured content list, normalize types
processed_content_parts = []
for chunk in msg_dict["content"]:
if isinstance(chunk, dict):
chunk_type = chunk.get("type")
if chunk_type == "image_url":
image_data.append(chunk["image_url"]["url"])
if chunk.get("modalities"):
modalities.append(chunk.get("modalities"))
# Normalize to simple 'image' type for template compatibility
processed_content_parts.append({"type": "image"})
elif chunk_type == "audio_url":
audio_data.append(chunk["audio_url"]["url"])
# Normalize to simple 'audio' type
processed_content_parts.append({"type": "audio"})
else:
# Keep other content as-is (text, etc.)
processed_content_parts.append(chunk)
new_msg = {
k: v for k, v in msg_dict.items() if v is not None and k != "content"
}
new_msg["content"] = processed_content_parts
return new_msg
else: # content_format == "string"
# String format: flatten to text only (for templates like DeepSeek)
text_parts = []
for chunk in msg_dict["content"]:
if isinstance(chunk, dict) and chunk.get("type") == "text":
text_parts.append(chunk["text"])
# Note: For string format, we ignore images/audio since the template
# doesn't expect structured content - multimodal placeholders would
# need to be inserted differently
new_msg = msg_dict.copy()
new_msg["content"] = " ".join(text_parts) if text_parts else ""
new_msg = {k: v for k, v in new_msg.items() if v is not None}
return new_msg
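A hedged usage sketch of the processor above, assuming process_content_for_template_format is importable from wherever it now lives; the message text and URL are illustrative only:

```python
image_data, audio_data, modalities = [], [], []
msg = {
    "role": "user",
    "content": [
        {"type": "text", "text": "What is in this image?"},
        {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
    ],
}

# 'openai' format: keeps a structured content list, normalizes image_url -> image
openai_msg = process_content_for_template_format(
    msg, "openai", image_data, audio_data, modalities
)
# openai_msg["content"] -> [{"type": "text", ...}, {"type": "image"}]
# image_data            -> ["https://example.com/cat.png"]

# 'string' format: flattens to the text parts only
string_msg = process_content_for_template_format(
    msg, "string", image_data, audio_data, modalities
)
# string_msg["content"] -> "What is in this image?"
```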
def to_openai_style_logprobs( def to_openai_style_logprobs(
input_token_logprobs=None, input_token_logprobs=None,
output_token_logprobs=None, output_token_logprobs=None,
......
...@@ -6,6 +6,7 @@ from typing import Any, Dict, List ...@@ -6,6 +6,7 @@ from typing import Any, Dict, List
from partial_json_parser.core.exceptions import MalformedJSON from partial_json_parser.core.exceptions import MalformedJSON
from partial_json_parser.core.options import Allow from partial_json_parser.core.options import Allow
from sglang.srt.entrypoints.openai.protocol import Tool
from sglang.srt.function_call.core_types import ( from sglang.srt.function_call.core_types import (
StreamingParseResult, StreamingParseResult,
ToolCallItem, ToolCallItem,
...@@ -16,7 +17,6 @@ from sglang.srt.function_call.utils import ( ...@@ -16,7 +17,6 @@ from sglang.srt.function_call.utils import (
_is_complete_json, _is_complete_json,
_partial_json_loads, _partial_json_loads,
) )
from sglang.srt.openai_api.protocol import Tool
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
......
...@@ -3,6 +3,7 @@ import logging ...@@ -3,6 +3,7 @@ import logging
import re import re
from typing import List from typing import List
from sglang.srt.entrypoints.openai.protocol import Tool
from sglang.srt.function_call.base_format_detector import BaseFormatDetector from sglang.srt.function_call.base_format_detector import BaseFormatDetector
from sglang.srt.function_call.core_types import ( from sglang.srt.function_call.core_types import (
StreamingParseResult, StreamingParseResult,
...@@ -12,7 +13,6 @@ from sglang.srt.function_call.core_types import ( ...@@ -12,7 +13,6 @@ from sglang.srt.function_call.core_types import (
) )
from sglang.srt.function_call.ebnf_composer import EBNFComposer from sglang.srt.function_call.ebnf_composer import EBNFComposer
from sglang.srt.function_call.utils import _is_complete_json from sglang.srt.function_call.utils import _is_complete_json
from sglang.srt.openai_api.protocol import Tool
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
......
import logging import logging
from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Type, Union from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Type, Union
from sglang.srt.entrypoints.openai.protocol import (
StructuralTagResponseFormat,
StructuresResponseFormat,
Tool,
ToolChoice,
)
from sglang.srt.function_call.base_format_detector import BaseFormatDetector from sglang.srt.function_call.base_format_detector import BaseFormatDetector
from sglang.srt.function_call.core_types import ToolCallItem from sglang.srt.function_call.core_types import ToolCallItem
from sglang.srt.function_call.deepseekv3_detector import DeepSeekV3Detector from sglang.srt.function_call.deepseekv3_detector import DeepSeekV3Detector
...@@ -8,12 +14,6 @@ from sglang.srt.function_call.llama32_detector import Llama32Detector ...@@ -8,12 +14,6 @@ from sglang.srt.function_call.llama32_detector import Llama32Detector
from sglang.srt.function_call.mistral_detector import MistralDetector from sglang.srt.function_call.mistral_detector import MistralDetector
from sglang.srt.function_call.pythonic_detector import PythonicDetector from sglang.srt.function_call.pythonic_detector import PythonicDetector
from sglang.srt.function_call.qwen25_detector import Qwen25Detector from sglang.srt.function_call.qwen25_detector import Qwen25Detector
from sglang.srt.openai_api.protocol import (
StructuralTagResponseFormat,
StructuresResponseFormat,
Tool,
ToolChoice,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
......