Unverified commit 72676cd6 authored by Chang Su, committed by GitHub

feat(oai refactor): Replace `openai_api` with `entrypoints/openai` (#7351)


Co-authored-by: Jin Pan <jpan236@wisc.edu>
parent 02bf31ef
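For downstream code, the user-visible change in this commit is the import path: everything that previously lived under sglang.srt.openai_api now lives under sglang.srt.entrypoints.openai. A condensed before/after sketch, using a class that appears in the hunks below:

# Before this commit:
# from sglang.srt.openai_api.protocol import ChatCompletionRequest

# After this commit:
from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest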
......@@ -20,7 +20,7 @@ from sglang.bench_serving import (
get_gen_prefix_cache_path,
)
from sglang.lang.chat_template import get_chat_template, get_chat_template_by_model_path
from sglang.srt.openai_api.protocol import ChatCompletionMessageContentPart
from sglang.srt.entrypoints.openai.protocol import ChatCompletionMessageContentPart
from sglang.utils import encode_video_base64
# Type of the content fields: prompts only, or prompts with images/videos
......
......@@ -64,11 +64,14 @@
"text = \"Once upon a time\"\n",
"\n",
"curl_text = f\"\"\"curl -s http://localhost:{port}/v1/embeddings \\\n",
" -H \"Content-Type: application/json\" \\\n",
" -d '{{\"model\": \"Alibaba-NLP/gte-Qwen2-1.5B-instruct\", \"input\": \"{text}\"}}'\"\"\"\n",
"\n",
"text_embedding = json.loads(subprocess.check_output(curl_text, shell=True))[\"data\"][0][\n",
" \"embedding\"\n",
"]\n",
"result = subprocess.check_output(curl_text, shell=True)\n",
"\n",
"print(result)\n",
"\n",
"text_embedding = json.loads(result)[\"data\"][0][\"embedding\"]\n",
"\n",
"print_highlight(f\"Text embedding (first 10): {text_embedding[:10]}\")"
]
......@@ -152,6 +155,7 @@
"input_ids = tokenizer.encode(text)\n",
"\n",
"curl_ids = f\"\"\"curl -s http://localhost:{port}/v1/embeddings \\\n",
" -H \"Content-Type: application/json\" \\\n",
" -d '{{\"model\": \"Alibaba-NLP/gte-Qwen2-1.5B-instruct\", \"input\": {json.dumps(input_ids)}}}'\"\"\"\n",
"\n",
"input_ids_embedding = json.loads(subprocess.check_output(curl_ids, shell=True))[\"data\"][\n",
......
......@@ -67,6 +67,7 @@
"\n",
"curl_command = f\"\"\"\n",
"curl -s http://localhost:{port}/v1/chat/completions \\\\\n",
" -H \"Content-Type: application/json\" \\\\\n",
" -d '{{\n",
" \"model\": \"Qwen/Qwen2.5-VL-7B-Instruct\",\n",
" \"messages\": [\n",
......
......@@ -36,7 +36,7 @@
"import requests\n",
"from PIL import Image\n",
"\n",
"from sglang.srt.openai_api.protocol import ChatCompletionRequest\n",
"from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest\n",
"from sglang.srt.conversation import chat_templates\n",
"\n",
"image = Image.open(\n",
......
......@@ -15,9 +15,7 @@
import dataclasses
import json
import logging
import os
from enum import auto
from sglang.srt.entrypoints.openai.protocol import CompletionRequest
......@@ -57,46 +55,6 @@ class CompletionTemplate:
completion_templates: dict[str, CompletionTemplate] = {}
def load_completion_template_for_openai_api(completion_template_arg):
global completion_template_name
logger.info(
f"Use completion template for the OpenAI-compatible API server: {completion_template_arg}"
)
if not completion_template_exists(completion_template_arg):
if not os.path.exists(completion_template_arg):
raise RuntimeError(
f"Completion template {completion_template_arg} is not a built-in template name "
"or a valid completion template file path."
)
assert completion_template_arg.endswith(
".json"
), "unrecognized format of completion template file"
with open(completion_template_arg, "r") as filep:
template = json.load(filep)
try:
fim_position = FimPosition[template["fim_position"]]
except KeyError:
raise ValueError(
f"Unknown fim position: {template['fim_position']}"
) from None
register_completion_template(
CompletionTemplate(
name=template["name"],
fim_begin_token=template["fim_begin_token"],
fim_middle_token=template["fim_middle_token"],
fim_end_token=template["fim_end_token"],
fim_position=fim_position,
),
override=True,
)
completion_template_name = template["name"]
else:
completion_template_name = completion_template_arg
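For reference, the loader above accepts either a built-in template name or a path to a JSON file whose fim_position must be a member name of the FimPosition enum. A hypothetical template file, written from Python purely for illustration (the token strings and the "MIDDLE" member name are assumptions, not taken from any real model):

import json

# Hypothetical FIM completion template; all values are illustrative only.
template = {
    "name": "demo-fim-template",
    "fim_begin_token": "<fim_prefix>",
    "fim_middle_token": "<fim_middle>",
    "fim_end_token": "<fim_suffix>",
    "fim_position": "MIDDLE",  # assumed FimPosition member name; check the enum for valid names
}

with open("demo_completion_template.json", "w") as f:
    json.dump(template, f)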
def register_completion_template(template: CompletionTemplate, override: bool = False):
"""Register a new completion template."""
if not override:
......
......@@ -11,7 +11,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Conversation chat templates."""
"""Conversation chat templates.
This module provides conversation template definitions, data structures, and utilities
for managing chat templates across different model types in SGLang.
Key components:
- Conversation class: Defines the structure and behavior of chat templates
- SeparatorStyle enum: Different conversation formatting styles
- Template registry: Functions to register and retrieve templates by name or model path
- Built-in templates: Pre-defined templates for popular models
"""
# Adapted from
# https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
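As a quick illustration of the "Template registry" component described in the docstring above, a hedged registration sketch; the Conversation field names follow the FastChat-derived dataclass, and the concrete values are hypothetical:

from sglang.srt.conversation import (
    Conversation,
    SeparatorStyle,
    register_conv_template,
)

# Hypothetical template; field names mirror the FastChat-derived dataclass,
# concrete values are illustrative only.
register_conv_template(
    Conversation(
        name="demo-template",
        system_message="You are a helpful assistant.",
        roles=("USER", "ASSISTANT"),
        sep_style=SeparatorStyle.ADD_COLON_SINGLE,
        sep="\n",
        stop_str="</s>",
    )
)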
......@@ -20,7 +30,7 @@ import re
from enum import IntEnum, auto
from typing import Callable, Dict, List, Optional, Tuple, Union
from sglang.srt.openai_api.protocol import ChatCompletionRequest
from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest
from sglang.srt.utils import read_system_prompt_from_file
......@@ -618,7 +628,7 @@ def generate_chat_conv(
# llama2 template
# reference: https://huggingface.co/blog/codellama#conversational-instructions
# reference: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
# reference: https://github.com/facebookresearch/llama/blob/1a240688810f8036049e8da36b073f63d2ac552c/llama/generation.py#L212
register_conv_template(
Conversation(
......
......@@ -37,7 +37,6 @@ setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
import torch
import uvloop
from sglang.srt.code_completion_parser import load_completion_template_for_openai_api
from sglang.srt.entrypoints.EngineBase import EngineBase
from sglang.srt.managers.data_parallel_controller import (
run_data_parallel_controller_process,
......@@ -58,11 +57,8 @@ from sglang.srt.managers.io_struct import (
UpdateWeightsFromTensorReqInput,
)
from sglang.srt.managers.scheduler import run_scheduler_process
from sglang.srt.managers.template_manager import TemplateManager
from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.openai_api.adapter import (
guess_chat_template_name_from_model_path,
load_chat_template_for_openai_api,
)
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
from sglang.srt.utils import (
......@@ -123,12 +119,13 @@ class Engine(EngineBase):
logger.info(f"{server_args=}")
# Launch subprocesses
tokenizer_manager, scheduler_info = _launch_subprocesses(
tokenizer_manager, template_manager, scheduler_info = _launch_subprocesses(
server_args=server_args,
port_args=port_args,
)
self.server_args = server_args
self.tokenizer_manager = tokenizer_manager
self.template_manager = template_manager
self.scheduler_info = scheduler_info
context = zmq.Context(2)
......@@ -647,7 +644,7 @@ def _set_envs_and_config(server_args: ServerArgs):
def _launch_subprocesses(
server_args: ServerArgs, port_args: Optional[PortArgs] = None
) -> Tuple[TokenizerManager, Dict]:
) -> Tuple[TokenizerManager, TemplateManager, Dict]:
"""
Launch the TokenizerManager in the main process, the Scheduler in a subprocess, and the DetokenizerManager in another subprocess.
"""
......@@ -732,7 +729,7 @@ def _launch_subprocesses(
if os.getenv("SGLANG_BLOCK_NONZERO_RANK_CHILDREN") == "0":
# When using `Engine` as a Python API, we don't want to block here.
return None, None
return None, None, None
launch_dummy_health_check_server(server_args.host, server_args.port)
......@@ -741,7 +738,7 @@ def _launch_subprocesses(
logger.error(
f"Scheduler or DataParallelController {proc.pid} terminated with {proc.exitcode}"
)
return None, None
return None, None, None
# Launch detokenizer process
detoken_proc = mp.Process(
......@@ -755,15 +752,15 @@ def _launch_subprocesses(
# Launch tokenizer process
tokenizer_manager = TokenizerManager(server_args, port_args)
if server_args.chat_template:
load_chat_template_for_openai_api(
tokenizer_manager, server_args.chat_template, server_args.model_path
)
else:
guess_chat_template_name_from_model_path(server_args.model_path)
if server_args.completion_template:
load_completion_template_for_openai_api(server_args.completion_template)
# Initialize templates
template_manager = TemplateManager()
template_manager.initialize_templates(
tokenizer_manager=tokenizer_manager,
model_path=server_args.model_path,
chat_template=server_args.chat_template,
completion_template=server_args.completion_template,
)
# Wait for the model to finish loading
scheduler_infos = []
......@@ -787,4 +784,4 @@ def _launch_subprocesses(
# Assume all schedulers have the same scheduler_info
scheduler_info = scheduler_infos[0]
tokenizer_manager.max_req_input_len = scheduler_info["max_req_input_len"]
return tokenizer_manager, scheduler_info
return tokenizer_manager, template_manager, scheduler_info
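Callers of _launch_subprocesses now unpack three values and hand the TemplateManager to the OpenAI serving handlers, as the http_server hunk below does; condensed:

tokenizer_manager, template_manager, scheduler_info = _launch_subprocesses(
    server_args=server_args
)
# The serving handlers receive both managers (see the lifespan hook below).
serving_chat = OpenAIServingChat(tokenizer_manager, template_manager)
serving_completion = OpenAIServingCompletion(tokenizer_manager, template_manager)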
......@@ -38,7 +38,8 @@ import orjson
import requests
import uvicorn
import uvloop
from fastapi import FastAPI, File, Form, Request, UploadFile
from fastapi import Depends, FastAPI, Request, UploadFile
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import ORJSONResponse, Response, StreamingResponse
......@@ -47,6 +48,20 @@ from sglang.srt.disaggregation.utils import (
register_disaggregation_server,
)
from sglang.srt.entrypoints.engine import _launch_subprocesses
from sglang.srt.entrypoints.openai.protocol import (
ChatCompletionRequest,
CompletionRequest,
EmbeddingRequest,
ModelCard,
ModelList,
ScoringRequest,
V1RerankReqInput,
)
from sglang.srt.entrypoints.openai.serving_chat import OpenAIServingChat
from sglang.srt.entrypoints.openai.serving_completions import OpenAIServingCompletion
from sglang.srt.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from sglang.srt.entrypoints.openai.serving_rerank import OpenAIServingRerank
from sglang.srt.entrypoints.openai.serving_score import OpenAIServingScore
from sglang.srt.function_call.function_call_parser import FunctionCallParser
from sglang.srt.managers.io_struct import (
AbortReq,
......@@ -67,26 +82,11 @@ from sglang.srt.managers.io_struct import (
UpdateWeightFromDiskReqInput,
UpdateWeightsFromDistributedReqInput,
UpdateWeightsFromTensorReqInput,
V1RerankReqInput,
VertexGenerateReqInput,
)
from sglang.srt.managers.template_manager import TemplateManager
from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.metrics.func_timer import enable_func_timer
from sglang.srt.openai_api.adapter import (
v1_batches,
v1_cancel_batch,
v1_chat_completions,
v1_completions,
v1_delete_file,
v1_embeddings,
v1_files_create,
v1_rerank,
v1_retrieve_batch,
v1_retrieve_file,
v1_retrieve_file_content,
v1_score,
)
from sglang.srt.openai_api.protocol import ModelCard, ModelList
from sglang.srt.reasoning_parser import ReasoningParser
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import (
......@@ -109,6 +109,7 @@ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
@dataclasses.dataclass
class _GlobalState:
tokenizer_manager: TokenizerManager
template_manager: TemplateManager
scheduler_info: Dict
......@@ -123,6 +124,24 @@ def set_global_state(global_state: _GlobalState):
@asynccontextmanager
async def lifespan(fast_api_app: FastAPI):
server_args: ServerArgs = fast_api_app.server_args
# Initialize OpenAI serving handlers
fast_api_app.state.openai_serving_completion = OpenAIServingCompletion(
_global_state.tokenizer_manager, _global_state.template_manager
)
fast_api_app.state.openai_serving_chat = OpenAIServingChat(
_global_state.tokenizer_manager, _global_state.template_manager
)
fast_api_app.state.openai_serving_embedding = OpenAIServingEmbedding(
_global_state.tokenizer_manager, _global_state.template_manager
)
fast_api_app.state.openai_serving_score = OpenAIServingScore(
_global_state.tokenizer_manager
)
fast_api_app.state.openai_serving_rerank = OpenAIServingRerank(
_global_state.tokenizer_manager
)
if server_args.warmups is not None:
await execute_warmups(
server_args.warmups.split(","), _global_state.tokenizer_manager
......@@ -148,6 +167,36 @@ app.add_middleware(
allow_headers=["*"],
)
# Custom exception handlers to change validation error status codes
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
"""Override FastAPI's default 422 validation error with 400"""
return ORJSONResponse(
status_code=400,
content={
"detail": exc.errors(),
"body": exc.body,
},
)
async def validate_json_request(raw_request: Request):
"""Validate that the request content-type is application/json."""
content_type = raw_request.headers.get("content-type", "").lower()
media_type = content_type.split(";", maxsplit=1)[0]
if media_type != "application/json":
raise RequestValidationError(
errors=[
{
"loc": ["header", "content-type"],
"msg": "Unsupported Media Type: Only 'application/json' is allowed",
"type": "value_error",
}
]
)
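The practical effect of validate_json_request together with the handler above is that a request with a non-JSON content type (or an otherwise invalid body) now returns 400 instead of FastAPI's default 422. A small client-side sketch, assuming a server listening on localhost:30000:

import requests

# Wrong Content-Type -> rejected by validate_json_request, surfaced as 400.
resp = requests.post(
    "http://localhost:30000/v1/chat/completions",
    data='{"model": "test", "messages": [{"role": "user", "content": "hi"}]}',
    headers={"Content-Type": "text/plain"},
)
print(resp.status_code)  # expected: 400, with a "detail" list in the JSON body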
HEALTH_CHECK_TIMEOUT = int(os.getenv("SGLANG_HEALTH_CHECK_TIMEOUT", 20))
......@@ -330,13 +379,14 @@ async def classify_request(obj: EmbeddingReqInput, request: Request):
return _create_error_response(e)
@app.api_route("/v1/rerank", methods=["POST", "PUT"])
async def v1_rerank_request(obj: V1RerankReqInput, raw_request: Request):
try:
ret = await v1_rerank(_global_state.tokenizer_manager, obj, raw_request)
return ret
except ValueError as e:
return _create_error_response(e)
@app.api_route(
"/v1/rerank", methods=["POST", "PUT"], dependencies=[Depends(validate_json_request)]
)
async def v1_rerank_request(request: V1RerankReqInput, raw_request: Request):
"""Endpoint for reranking documents based on query relevance."""
return await raw_request.app.state.openai_serving_rerank.handle_request(
request, raw_request
)
@app.api_route("/flush_cache", methods=["GET", "POST"])
......@@ -619,25 +669,39 @@ async def separate_reasoning_request(obj: SeparateReasoningReqInput, request: Re
##### OpenAI-compatible API endpoints #####
@app.post("/v1/completions")
async def openai_v1_completions(raw_request: Request):
return await v1_completions(_global_state.tokenizer_manager, raw_request)
@app.post("/v1/completions", dependencies=[Depends(validate_json_request)])
async def openai_v1_completions(request: CompletionRequest, raw_request: Request):
"""OpenAI-compatible text completion endpoint."""
return await raw_request.app.state.openai_serving_completion.handle_request(
request, raw_request
)
@app.post("/v1/chat/completions")
async def openai_v1_chat_completions(raw_request: Request):
return await v1_chat_completions(_global_state.tokenizer_manager, raw_request)
@app.post("/v1/chat/completions", dependencies=[Depends(validate_json_request)])
async def openai_v1_chat_completions(
request: ChatCompletionRequest, raw_request: Request
):
"""OpenAI-compatible chat completion endpoint."""
return await raw_request.app.state.openai_serving_chat.handle_request(
request, raw_request
)
@app.post("/v1/embeddings", response_class=ORJSONResponse)
async def openai_v1_embeddings(raw_request: Request):
response = await v1_embeddings(_global_state.tokenizer_manager, raw_request)
return response
@app.post(
"/v1/embeddings",
response_class=ORJSONResponse,
dependencies=[Depends(validate_json_request)],
)
async def openai_v1_embeddings(request: EmbeddingRequest, raw_request: Request):
"""OpenAI-compatible embeddings endpoint."""
return await raw_request.app.state.openai_serving_embedding.handle_request(
request, raw_request
)
@app.get("/v1/models", response_class=ORJSONResponse)
def available_models():
"""Show available models."""
async def available_models():
"""Show available models. OpenAI-compatible endpoint."""
served_model_names = [_global_state.tokenizer_manager.served_model_name]
model_cards = []
for served_model_name in served_model_names:
......@@ -651,45 +715,29 @@ def available_models():
return ModelList(data=model_cards)
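A quick check of the endpoint above in the same client style (localhost:30000 assumed):

import requests

models = requests.get("http://localhost:30000/v1/models").json()
print([card["id"] for card in models["data"]])  # one entry per served model name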
@app.post("/v1/files")
async def openai_v1_files(file: UploadFile = File(...), purpose: str = Form("batch")):
return await v1_files_create(
file, purpose, _global_state.tokenizer_manager.server_args.file_storage_path
)
@app.delete("/v1/files/{file_id}")
async def delete_file(file_id: str):
# https://platform.openai.com/docs/api-reference/files/delete
return await v1_delete_file(file_id)
@app.post("/v1/batches")
async def openai_v1_batches(raw_request: Request):
return await v1_batches(_global_state.tokenizer_manager, raw_request)
@app.post("/v1/batches/{batch_id}/cancel")
async def cancel_batches(batch_id: str):
# https://platform.openai.com/docs/api-reference/batch/cancel
return await v1_cancel_batch(_global_state.tokenizer_manager, batch_id)
@app.get("/v1/batches/{batch_id}")
async def retrieve_batch(batch_id: str):
return await v1_retrieve_batch(batch_id)
@app.get("/v1/files/{file_id}")
async def retrieve_file(file_id: str):
# https://platform.openai.com/docs/api-reference/files/retrieve
return await v1_retrieve_file(file_id)
@app.get("/v1/models/{model:path}", response_class=ORJSONResponse)
async def retrieve_model(model: str):
"""Retrieves a model instance, providing basic information about the model."""
served_model_names = [_global_state.tokenizer_manager.served_model_name]
if model not in served_model_names:
return ORJSONResponse(
status_code=404,
content={
"error": {
"message": f"The model '{model}' does not exist",
"type": "invalid_request_error",
"param": "model",
"code": "model_not_found",
}
},
)
@app.get("/v1/files/{file_id}/content")
async def retrieve_file_content(file_id: str):
# https://platform.openai.com/docs/api-reference/files/retrieve-contents
return await v1_retrieve_file_content(file_id)
return ModelCard(
id=model,
root=model,
max_model_len=_global_state.tokenizer_manager.model_config.context_len,
)
## SageMaker API
......@@ -700,8 +748,13 @@ async def sagemaker_health() -> Response:
@app.post("/invocations")
async def sagemaker_chat_completions(raw_request: Request):
return await v1_chat_completions(_global_state.tokenizer_manager, raw_request)
async def sagemaker_chat_completions(
request: ChatCompletionRequest, raw_request: Request
):
"""OpenAI-compatible chat completion endpoint."""
return await raw_request.app.state.openai_serving_chat.handle_request(
request, raw_request
)
## Vertex AI API
......@@ -732,10 +785,12 @@ async def vertex_generate(vertex_req: VertexGenerateReqInput, raw_request: Reque
return ORJSONResponse({"predictions": ret})
@app.post("/v1/score")
async def v1_score_request(raw_request: Request):
@app.post("/v1/score", dependencies=[Depends(validate_json_request)])
async def v1_score_request(request: ScoringRequest, raw_request: Request):
"""Endpoint for the decoder-only scoring API. See Engine.score() for detailed documentation."""
return await v1_score(_global_state.tokenizer_manager, raw_request)
return await raw_request.app.state.openai_serving_score.handle_request(
request, raw_request
)
def _create_error_response(e):
......@@ -764,10 +819,13 @@ def launch_server(
1. The HTTP server, Engine, and TokenizerManager all run in the main process.
2. Inter-process communication is done through IPC (each process uses a different port) via the ZMQ library.
"""
tokenizer_manager, scheduler_info = _launch_subprocesses(server_args=server_args)
tokenizer_manager, template_manager, scheduler_info = _launch_subprocesses(
server_args=server_args
)
set_global_state(
_GlobalState(
tokenizer_manager=tokenizer_manager,
template_manager=template_manager,
scheduler_info=scheduler_info,
)
)
......
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
SGLang OpenAI-Compatible API Server.
This file implements OpenAI-compatible HTTP APIs for the inference engine via FastAPI.
"""
import argparse
import asyncio
import logging
import multiprocessing
import os
import threading
import time
from contextlib import asynccontextmanager
from typing import Callable, Dict, Optional
import numpy as np
import requests
import uvicorn
import uvloop
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import Response
from sglang.srt.disaggregation.utils import (
FAKE_BOOTSTRAP_HOST,
register_disaggregation_server,
)
from sglang.srt.entrypoints.engine import Engine, _launch_subprocesses
from sglang.srt.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.metrics.func_timer import enable_func_timer
from sglang.srt.openai_api.protocol import EmbeddingRequest, ModelCard, ModelList
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import (
add_prometheus_middleware,
delete_directory,
get_bool_env_var,
kill_process_tree,
set_uvicorn_logging_configs,
)
from sglang.srt.warmup import execute_warmups
from sglang.utils import get_exception_traceback
logger = logging.getLogger(__name__)
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
# Store global states
class AppState:
engine: Optional[Engine] = None
server_args: Optional[ServerArgs] = None
tokenizer_manager: Optional[TokenizerManager] = None
scheduler_info: Optional[Dict] = None
embedding_server: Optional[OpenAIServingEmbedding] = None
@asynccontextmanager
async def lifespan(app: FastAPI):
app.state.server_args.enable_metrics = True # By default, we enable metrics
server_args = app.state.server_args
# Initialize engine
logger.info(f"SGLang OpenAI server (PID: {os.getpid()}) is initializing...")
tokenizer_manager, scheduler_info = _launch_subprocesses(server_args=server_args)
app.state.tokenizer_manager = tokenizer_manager
app.state.scheduler_info = scheduler_info
app.state.serving_embedding = OpenAIServingEmbedding(
tokenizer_manager=tokenizer_manager
)
if server_args.enable_metrics:
add_prometheus_middleware(app)
enable_func_timer()
# Initialize engine state attribute to None for now
app.state.engine = None
if server_args.warmups is not None:
await execute_warmups(
server_args.warmups.split(","), app.state.tokenizer_manager
)
logger.info("Warmup ended")
warmup_thread = getattr(app, "warmup_thread", None)
if warmup_thread is not None:
warmup_thread.start()
yield
# Lifespan shutdown
if hasattr(app.state, "engine") and app.state.engine is not None:
logger.info("SGLang engine is shutting down.")
# Add engine cleanup logic here when implemented
# Fast API app with CORS enabled
app = FastAPI(
lifespan=lifespan,
# TODO: check where /openapi.json is created or why we use this
openapi_url=None if get_bool_env_var("DISABLE_OPENAPI_DOC") else "/openapi.json",
)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
@app.api_route("/health", methods=["GET"])
async def health() -> Response:
"""Health check. Used for readiness and liveness probes."""
# In the future, this could check engine health more deeply
# For now, if the server is up, it's healthy.
return Response(status_code=200)
@app.api_route("/v1/models", methods=["GET"])
async def show_models():
"""Show available models. Currently, it returns the served model name.
This endpoint is compatible with the OpenAI API standard.
"""
served_model_names = [app.state.tokenizer_manager.served_model_name]
model_cards = []
for served_model_name in served_model_names:
model_cards.append(
ModelCard(
id=served_model_name,
root=served_model_name,
max_model_len=app.state.tokenizer_manager.model_config.context_len,
)
)
return ModelList(data=model_cards)
@app.get("/get_model_info")
async def get_model_info():
"""Get the model information."""
result = {
"model_path": app.state.tokenizer_manager.model_path,
"tokenizer_path": app.state.tokenizer_manager.server_args.tokenizer_path,
"is_generation": app.state.tokenizer_manager.is_generation,
}
return result
@app.post("/v1/completions")
async def openai_v1_completions(raw_request: Request):
pass
@app.post("/v1/chat/completions")
async def openai_v1_chat_completions(raw_request: Request):
pass
@app.post("/v1/embeddings")
async def openai_v1_embeddings(raw_request: Request):
try:
request_json = await raw_request.json()
request = EmbeddingRequest(**request_json)
except Exception as e:
return app.state.serving_embedding.create_error_response(
f"Invalid request body, error: {str(e)}"
)
ret = await app.state.serving_embedding.handle_request(request, raw_request)
return ret
@app.post("/v1/score")
async def v1_score_request(raw_request: Request):
"""Endpoint for the decoder-only scoring API. See Engine.score() for detailed documentation."""
pass
@app.api_route("/v1/models/{model_id}", methods=["GET"])
async def show_model_detail(model_id: str):
served_model_name = app.state.tokenizer_manager.served_model_name
return ModelCard(
id=served_model_name,
root=served_model_name,
max_model_len=app.state.tokenizer_manager.model_config.context_len,
)
# Additional API endpoints will be implemented in separate serving_*.py modules
# and mounted as APIRouters in future PRs
def _wait_and_warmup(
server_args: ServerArgs,
pipe_finish_writer: Optional[multiprocessing.connection.Connection],
image_token_text: str,
launch_callback: Optional[Callable[[], None]] = None,
):
return
# TODO: Keep this early return until the /generate implementation is complete,
# or confirm whether modifications are needed before removing it.
headers = {}
url = server_args.url()
if server_args.api_key:
headers["Authorization"] = f"Bearer {server_args.api_key}"
# Wait until the server is launched
success = False
for _ in range(120):
time.sleep(1)
try:
res = requests.get(url + "/get_model_info", timeout=5, headers=headers)
assert res.status_code == 200, f"{res=}, {res.text=}"
success = True
break
except (AssertionError, requests.exceptions.RequestException):
last_traceback = get_exception_traceback()
pass
if not success:
if pipe_finish_writer is not None:
pipe_finish_writer.send(last_traceback)
logger.error(f"Initialization failed. warmup error: {last_traceback}")
kill_process_tree(os.getpid())
return
model_info = res.json()
# Send a warmup request
request_name = "/generate" if model_info["is_generation"] else "/encode"
# TODO: Replace with OpenAI API
max_new_tokens = 8 if model_info["is_generation"] else 1
json_data = {
"sampling_params": {
"temperature": 0,
"max_new_tokens": max_new_tokens,
},
}
if server_args.skip_tokenizer_init:
json_data["input_ids"] = [[10, 11, 12] for _ in range(server_args.dp_size)]
# TODO: workaround for a bug where embedding fails for an input list of size 1
if server_args.dp_size == 1:
json_data["input_ids"] = json_data["input_ids"][0]
else:
json_data["text"] = ["The capital city of France is"] * server_args.dp_size
# TODO: workaround for a bug where embedding fails for an input list of size 1
if server_args.dp_size == 1:
json_data["text"] = json_data["text"][0]
# Debug dumping
if server_args.debug_tensor_dump_input_file:
json_data.pop("text", None)
json_data["input_ids"] = np.load(
server_args.debug_tensor_dump_input_file
).tolist()
json_data["sampling_params"]["max_new_tokens"] = 0
try:
if server_args.disaggregation_mode == "null":
res = requests.post(
url + request_name,
json=json_data,
headers=headers,
timeout=600,
)
assert res.status_code == 200, f"{res}"
else:
logger.info(f"Start of prefill warmup ...")
json_data = {
"sampling_params": {
"temperature": 0.0,
"max_new_tokens": 8,
"ignore_eos": True,
},
"bootstrap_host": [FAKE_BOOTSTRAP_HOST] * server_args.dp_size,
# This is a hack to ensure fake transfer is enabled during prefill warmup;
# it gives each dp rank a unique bootstrap_room during prefill warmup.
"bootstrap_room": [
i * (2**63 // server_args.dp_size) + (i % server_args.tp_size)
for i in range(server_args.dp_size)
],
"input_ids": [[0, 1, 2, 3]] * server_args.dp_size,
}
res = requests.post(
url + request_name,
json=json_data,
headers=headers,
timeout=1800,  # DeepGEMM precaching can take a long time when the cache is cold
)
logger.info(
f"End of prefill warmup with status {res.status_code}, resp: {res.json()}"
)
except Exception:
last_traceback = get_exception_traceback()
if pipe_finish_writer is not None:
pipe_finish_writer.send(last_traceback)
logger.error(f"Initialization failed. warmup error: {last_traceback}")
kill_process_tree(os.getpid())
return
# Debug print
# logger.info(f"{res.json()=}")
logger.info("The server is fired up and ready to roll!")
if pipe_finish_writer is not None:
pipe_finish_writer.send("ready")
if server_args.delete_ckpt_after_loading:
delete_directory(server_args.model_path)
if server_args.debug_tensor_dump_input_file:
kill_process_tree(os.getpid())
if server_args.pdlb_url is not None:
register_disaggregation_server(
server_args.disaggregation_mode,
server_args.port,
server_args.disaggregation_bootstrap_port,
server_args.pdlb_url,
)
if launch_callback is not None:
launch_callback()
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="SGLang OpenAI-Compatible API Server")
# Add arguments from ServerArgs. This allows reuse of existing CLI definitions.
ServerArgs.add_cli_args(parser)
# Potentially add server-specific arguments here in the future if needed
args = parser.parse_args()
server_args = ServerArgs.from_cli_args(args)
# Store server_args in app.state for access in lifespan and endpoints
app.state.server_args = server_args
# Configure logging
logging.basicConfig(
level=server_args.log_level.upper(),
format="%(asctime)s - %(name)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s",
)
# Create the warmup thread here; it is started in the lifespan
# after all other warmups have fired.
warmup_thread = threading.Thread(
target=_wait_and_warmup,
args=(
server_args,
None,
None, # Never used
None,
),
)
app.warmup_thread = warmup_thread
try:
# Start the server
set_uvicorn_logging_configs()
uvicorn.run(
app,
host=server_args.host,
port=server_args.port,
log_level=server_args.log_level.lower(),
timeout_keep_alive=60, # Increased keep-alive for potentially long requests
loop="uvloop", # Use uvloop for better performance if available
)
finally:
warmup_thread.join()
......@@ -207,7 +207,7 @@ class CompletionResponseChoice(BaseModel):
index: int
text: str
logprobs: Optional[LogProbs] = None
finish_reason: Literal["stop", "length", "content_filter", "abort"]
finish_reason: Optional[Literal["stop", "length", "content_filter", "abort"]] = None
matched_stop: Union[None, int, str] = None
hidden_states: Optional[object] = None
......@@ -404,21 +404,13 @@ class ChatCompletionRequest(BaseModel):
@model_validator(mode="before")
@classmethod
def set_tool_choice_default(cls, values):
if isinstance(values, dict):
if values.get("tool_choice") is None:
if values.get("tools") is None:
values["tool_choice"] = "none"
else:
values["tool_choice"] = "auto"
if values.get("tool_choice") is None:
if values.get("tools") is None:
values["tool_choice"] = "none"
else:
values["tool_choice"] = "auto"
return values
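The defaulting behavior above, sketched; this assumes ChatCompletionRequest accepts OpenAI-style message and tool dicts, as the other hunks in this PR suggest:

from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest

msgs = [{"role": "user", "content": "hi"}]

# No tools supplied -> tool_choice defaults to "none".
print(ChatCompletionRequest(model="m", messages=msgs).tool_choice)

# Tools supplied but no explicit tool_choice -> defaults to "auto".
tools = [{"type": "function", "function": {"name": "f", "parameters": {}}}]
print(ChatCompletionRequest(model="m", messages=msgs, tools=tools).tool_choice)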
@field_validator("messages")
@classmethod
def validate_messages_not_empty(cls, v):
if not v:
raise ValueError("Messages cannot be empty")
return v
# Extra parameters for SRT backend only and will be ignored by OpenAI models.
top_k: int = -1
min_p: float = 0.0
......@@ -457,9 +449,11 @@ class ChatCompletionResponseChoice(BaseModel):
index: int
message: ChatMessage
logprobs: Optional[Union[LogProbs, ChoiceLogprobs]] = None
finish_reason: Literal[
"stop", "length", "tool_calls", "content_filter", "function_call", "abort"
]
finish_reason: Optional[
Literal[
"stop", "length", "tool_calls", "content_filter", "function_call", "abort"
]
] = None
matched_stop: Union[None, int, str] = None
hidden_states: Optional[object] = None
......@@ -530,7 +524,7 @@ class EmbeddingRequest(BaseModel):
input: EmbeddingInput
model: str
encoding_format: str = "float"
dimensions: int = None
dimensions: Optional[int] = None
user: Optional[str] = None
# The request id.
......
......@@ -2,16 +2,12 @@ import json
import logging
import uuid
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional, Union
from typing import Any, Optional, Union
from fastapi import Request
from fastapi.responses import ORJSONResponse, StreamingResponse
from sglang.srt.entrypoints.openai.protocol import (
ErrorResponse,
OpenAIServingRequest,
UsageInfo,
)
from sglang.srt.entrypoints.openai.protocol import ErrorResponse, OpenAIServingRequest
from sglang.srt.managers.io_struct import GenerateReqInput
from sglang.srt.managers.tokenizer_manager import TokenizerManager
......@@ -51,7 +47,7 @@ class OpenAIServingBase(ABC):
)
except Exception as e:
logger.error(f"Error in request: {e}")
logger.exception(f"Error in request: {e}")
return self.create_error_response(
message=f"Internal server error: {str(e)}",
err_type="InternalServerError",
......@@ -63,8 +59,12 @@ class OpenAIServingBase(ABC):
"""Generate request ID based on request type"""
pass
def _generate_request_id_base(self, request: OpenAIServingRequest) -> str:
def _generate_request_id_base(self, request: OpenAIServingRequest) -> Optional[str]:
"""Generate request ID based on request type"""
return None
# TODO(chang): the rid is used in the io_struct check and often trips the
# `The rid should be a list` assertion.
# Temporarily return None from this function until the rid logic is clarified.
if rid := getattr(request, "rid", None):
return rid
......@@ -83,7 +83,7 @@ class OpenAIServingBase(ABC):
adapted_request: GenerateReqInput,
request: OpenAIServingRequest,
raw_request: Request,
) -> StreamingResponse:
) -> Union[StreamingResponse, ErrorResponse, ORJSONResponse]:
"""Handle streaming request
Override this method in child classes that support streaming requests.
......@@ -99,7 +99,7 @@ class OpenAIServingBase(ABC):
adapted_request: GenerateReqInput,
request: OpenAIServingRequest,
raw_request: Request,
) -> Union[Any, ErrorResponse]:
) -> Union[Any, ErrorResponse, ORJSONResponse]:
"""Handle non-streaming request
Override this method in child classes that support non-streaming requests.
......@@ -110,7 +110,7 @@ class OpenAIServingBase(ABC):
status_code=501,
)
def _validate_request(self, request: OpenAIServingRequest) -> Optional[str]:
def _validate_request(self, _: OpenAIServingRequest) -> Optional[str]:
"""Validate request"""
pass
......@@ -122,6 +122,7 @@ class OpenAIServingBase(ABC):
param: Optional[str] = None,
) -> ORJSONResponse:
"""Create an error response"""
# TODO: remove fastapi dependency in openai and move response handling to the entrypoint
error = ErrorResponse(
object="error",
message=message,
......
import base64
import json
import logging
import time
......@@ -6,7 +5,7 @@ import uuid
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
from fastapi import Request
from fastapi.responses import StreamingResponse
from fastapi.responses import ORJSONResponse, StreamingResponse
from sglang.srt.conversation import generate_chat_conv
from sglang.srt.entrypoints.openai.protocol import (
......@@ -28,13 +27,14 @@ from sglang.srt.entrypoints.openai.protocol import (
from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
from sglang.srt.entrypoints.openai.usage_processor import UsageProcessor
from sglang.srt.entrypoints.openai.utils import (
detect_template_content_format,
process_content_for_template_format,
process_hidden_states_from_ret,
to_openai_style_logprobs,
)
from sglang.srt.function_call.function_call_parser import FunctionCallParser
from sglang.srt.jinja_template_utils import process_content_for_template_format
from sglang.srt.managers.io_struct import GenerateReqInput
from sglang.srt.managers.template_manager import TemplateManager
from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.reasoning_parser import ReasoningParser
from sglang.utils import convert_json_schema_to_str
......@@ -42,13 +42,13 @@ logger = logging.getLogger(__name__)
class OpenAIServingChat(OpenAIServingBase):
"""Handler for chat completion requests"""
"""Handler for /v1/chat/completions requests"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Instance-specific cache for template content format detection
self._cached_chat_template = None
self._cached_template_format = None
def __init__(
self, tokenizer_manager: TokenizerManager, template_manager: TemplateManager
):
super().__init__(tokenizer_manager)
self.template_manager = template_manager
def _request_id_prefix(self) -> str:
return "chatcmpl-"
......@@ -142,19 +142,14 @@ class OpenAIServingChat(OpenAIServingBase):
)
# Use chat template
if (
hasattr(self.tokenizer_manager, "chat_template_name")
and self.tokenizer_manager.chat_template_name is None
):
if self.template_manager.chat_template_name is None:
prompt, prompt_ids, image_data, audio_data, modalities, stop = (
self._apply_jinja_template(request, tools, is_multimodal)
)
else:
prompt, image_data, audio_data, modalities, stop = (
self._apply_conversation_template(request)
prompt, prompt_ids, image_data, audio_data, modalities, stop = (
self._apply_conversation_template(request, is_multimodal)
)
if not is_multimodal:
prompt_ids = self.tokenizer_manager.tokenizer.encode(prompt)
else:
# Use raw prompt
prompt_ids = request.messages
......@@ -181,23 +176,14 @@ class OpenAIServingChat(OpenAIServingBase):
is_multimodal: bool,
) -> tuple[str, List[int], Optional[Any], Optional[Any], List[str], List[str]]:
"""Apply Jinja chat template"""
prompt = ""
prompt_ids = []
openai_compatible_messages = []
image_data = []
audio_data = []
modalities = []
# Detect template content format
current_template = self.tokenizer_manager.tokenizer.chat_template
if current_template != self._cached_chat_template:
self._cached_chat_template = current_template
self._cached_template_format = detect_template_content_format(
current_template
)
logger.info(
f"Detected chat template content format: {self._cached_template_format}"
)
template_content_format = self._cached_template_format
template_content_format = self.template_manager.jinja_template_content_format
for message in request.messages:
if message.content is None:
......@@ -262,14 +248,21 @@ class OpenAIServingChat(OpenAIServingBase):
if is_multimodal:
prompt = self.tokenizer_manager.tokenizer.decode(prompt_ids)
stop = request.stop or []
stop = request.stop
image_data = image_data if image_data else None
audio_data = audio_data if audio_data else None
modalities = modalities if modalities else []
return prompt, prompt_ids, image_data, audio_data, modalities, stop
def _apply_conversation_template(
self, request: ChatCompletionRequest
) -> tuple[str, Optional[Any], Optional[Any], List[str], List[str]]:
self,
request: ChatCompletionRequest,
is_multimodal: bool,
) -> tuple[str, Optional[Any], Optional[Any], List[str], List[str], List[str]]:
"""Apply conversation template"""
conv = generate_chat_conv(request, self.tokenizer_manager.chat_template_name)
prompt = ""
prompt_ids = []
conv = generate_chat_conv(request, self.template_manager.chat_template_name)
# If we should continue the final assistant message, adjust the conversation.
if (
......@@ -296,9 +289,9 @@ class OpenAIServingChat(OpenAIServingBase):
else:
prompt = conv.get_prompt()
image_data = conv.image_data
audio_data = conv.audio_data
modalities = conv.modalities
image_data = conv.image_data if conv.image_data else None
audio_data = conv.audio_data if conv.audio_data else None
modalities = conv.modalities if conv.modalities else []
stop = conv.stop_str or [] if not request.ignore_eos else []
if request.stop:
......@@ -307,7 +300,10 @@ class OpenAIServingChat(OpenAIServingBase):
else:
stop.extend(request.stop)
return prompt, image_data, audio_data, modalities, stop
if not is_multimodal:
prompt_ids = self.tokenizer_manager.tokenizer.encode(prompt)
return prompt, prompt_ids, image_data, audio_data, modalities, stop
def _build_sampling_params(
self,
......@@ -459,13 +455,9 @@ class OpenAIServingChat(OpenAIServingBase):
stream_buffers[index] = stream_buffer + delta
# Handle reasoning content
enable_thinking = getattr(request, "chat_template_kwargs", {}).get(
"enable_thinking", True
)
if (
self.tokenizer_manager.server_args.reasoning_parser
and request.separate_reasoning
and enable_thinking
):
reasoning_text, delta = self._process_reasoning_stream(
index, delta, reasoning_parser_dict, content, request
......@@ -591,7 +583,7 @@ class OpenAIServingChat(OpenAIServingBase):
)
yield f"data: {usage_chunk.model_dump_json()}\n\n"
except Exception as e:
except ValueError as e:
error = self.create_streaming_error_response(str(e))
yield f"data: {error}\n\n"
......@@ -602,7 +594,7 @@ class OpenAIServingChat(OpenAIServingBase):
adapted_request: GenerateReqInput,
request: ChatCompletionRequest,
raw_request: Request,
) -> Union[ChatCompletionResponse, ErrorResponse]:
) -> Union[ChatCompletionResponse, ErrorResponse, ORJSONResponse]:
"""Handle non-streaming chat completion request"""
try:
ret = await self.tokenizer_manager.generate_request(
......@@ -627,7 +619,7 @@ class OpenAIServingChat(OpenAIServingBase):
request: ChatCompletionRequest,
ret: List[Dict[str, Any]],
created: int,
) -> ChatCompletionResponse:
) -> Union[ChatCompletionResponse, ORJSONResponse]:
"""Build chat completion response from generation results"""
choices = []
......@@ -645,11 +637,8 @@ class OpenAIServingChat(OpenAIServingBase):
# Handle reasoning content
reasoning_text = None
enable_thinking = getattr(request, "chat_template_kwargs", {}).get(
"enable_thinking", True
)
reasoning_parser = self.tokenizer_manager.server_args.reasoning_parser
if reasoning_parser and request.separate_reasoning and enable_thinking:
if reasoning_parser and request.separate_reasoning:
try:
parser = ReasoningParser(
model_type=reasoning_parser, stream_reasoning=False
......@@ -691,9 +680,10 @@ class OpenAIServingChat(OpenAIServingBase):
choices.append(choice_data)
# Calculate usage
cache_report = self.tokenizer_manager.server_args.enable_cache_report
usage = UsageProcessor.calculate_response_usage(
ret, n_choices=request.n, enable_cache_report=cache_report
ret,
n_choices=request.n,
enable_cache_report=self.tokenizer_manager.server_args.enable_cache_report,
)
return ChatCompletionResponse(
......@@ -821,6 +811,25 @@ class OpenAIServingChat(OpenAIServingBase):
reasoning_parser = reasoning_parser_dict[index]
return reasoning_parser.parse_stream_chunk(delta)
def _get_enable_thinking_from_request(request: ChatCompletionRequest) -> bool:
"""Extracts the 'enable_thinking' flag from request chat_template_kwargs.
NOTE: This parameter is only useful for models that support enable_thinking
flag, such as Qwen3.
Args:
request_obj: The request object (or an item from a list of requests).
Returns:
The boolean value of 'enable_thinking' if found and not True, otherwise True.
"""
if (
hasattr(request, "chat_template_kwargs")
and request.chat_template_kwargs
and request.chat_template_kwargs.get("enable_thinking") is not None
):
return request.chat_template_kwargs.get("enable_thinking")
return True
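A minimal sketch of the lookup the helper performs (request construction assumed as in the protocol hunk above):

req = ChatCompletionRequest(
    model="m",
    messages=[{"role": "user", "content": "hi"}],
    chat_template_kwargs={"enable_thinking": False},
)

# Mirrors the helper's logic: fall back to True when the flag is absent.
kwargs = req.chat_template_kwargs or {}
enable_thinking = kwargs.get("enable_thinking")
print(True if enable_thinking is None else enable_thinking)  # False here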
async def _process_tool_call_stream(
self,
index: int,
......
......@@ -3,12 +3,9 @@ import time
from typing import Any, AsyncGenerator, Dict, List, Union
from fastapi import Request
from fastapi.responses import StreamingResponse
from fastapi.responses import ORJSONResponse, StreamingResponse
from sglang.srt.code_completion_parser import (
generate_completion_prompt_from_request,
is_completion_template_defined,
)
from sglang.srt.code_completion_parser import generate_completion_prompt_from_request
from sglang.srt.entrypoints.openai.protocol import (
CompletionRequest,
CompletionResponse,
......@@ -24,12 +21,22 @@ from sglang.srt.entrypoints.openai.utils import (
to_openai_style_logprobs,
)
from sglang.srt.managers.io_struct import GenerateReqInput
from sglang.srt.managers.template_manager import TemplateManager
from sglang.srt.managers.tokenizer_manager import TokenizerManager
logger = logging.getLogger(__name__)
class OpenAIServingCompletion(OpenAIServingBase):
"""Handler for completion requests"""
"""Handler for /v1/completion requests"""
def __init__(
self,
tokenizer_manager: TokenizerManager,
template_manager: TemplateManager,
):
super().__init__(tokenizer_manager)
self.template_manager = template_manager
def _request_id_prefix(self) -> str:
return "cmpl-"
......@@ -47,7 +54,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
)
# Process prompt
prompt = request.prompt
if is_completion_template_defined():
if self.template_manager.completion_template_name is not None:
prompt = generate_completion_prompt_from_request(request)
# Set logprob start length based on echo and logprobs
......@@ -141,6 +148,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
prompt_tokens = {}
completion_tokens = {}
cached_tokens = {}
hidden_states = {}
try:
async for content in self.tokenizer_manager.generate_request(
......@@ -152,6 +160,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
prompt_tokens[index] = content["meta_info"]["prompt_tokens"]
completion_tokens[index] = content["meta_info"]["completion_tokens"]
cached_tokens[index] = content["meta_info"].get("cached_tokens", 0)
hidden_states[index] = content["meta_info"].get("hidden_states", None)
stream_buffer = stream_buffers.get(index, "")
# Handle echo for first chunk
......@@ -192,7 +201,6 @@ class OpenAIServingCompletion(OpenAIServingBase):
delta = text[len(stream_buffer) :]
stream_buffers[index] = stream_buffer + delta
finish_reason = content["meta_info"]["finish_reason"]
hidden_states = content["meta_info"].get("hidden_states", None)
choice_data = CompletionResponseStreamChoice(
index=index,
......@@ -269,7 +277,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
adapted_request: GenerateReqInput,
request: CompletionRequest,
raw_request: Request,
) -> Union[CompletionResponse, ErrorResponse]:
) -> Union[CompletionResponse, ErrorResponse, ORJSONResponse]:
"""Handle non-streaming completion request"""
try:
generator = self.tokenizer_manager.generate_request(
......
from typing import Any, Dict, List, Optional, Union
from fastapi import Request
from fastapi.responses import ORJSONResponse
from sglang.srt.conversation import generate_embedding_convs
from sglang.srt.entrypoints.openai.protocol import (
......@@ -13,10 +14,20 @@ from sglang.srt.entrypoints.openai.protocol import (
)
from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
from sglang.srt.managers.io_struct import EmbeddingReqInput
from sglang.srt.managers.template_manager import TemplateManager
from sglang.srt.managers.tokenizer_manager import TokenizerManager
class OpenAIServingEmbedding(OpenAIServingBase):
"""Handler for embedding requests"""
"""Handler for v1/embeddings requests"""
def __init__(
self,
tokenizer_manager: TokenizerManager,
template_manager: TemplateManager,
):
super().__init__(tokenizer_manager)
self.template_manager = template_manager
def _request_id_prefix(self) -> str:
return "embd-"
......@@ -68,11 +79,7 @@ class OpenAIServingEmbedding(OpenAIServingBase):
prompt_kwargs = {"text": prompt}
elif isinstance(prompt, list):
if len(prompt) > 0 and isinstance(prompt[0], str):
# List of strings - if it's a single string in a list, treat as single string
if len(prompt) == 1:
prompt_kwargs = {"text": prompt[0]}
else:
prompt_kwargs = {"text": prompt}
prompt_kwargs = {"text": prompt}
elif len(prompt) > 0 and isinstance(prompt[0], MultimodalEmbeddingInput):
# Handle multimodal embedding inputs
texts = []
......@@ -84,11 +91,10 @@ class OpenAIServingEmbedding(OpenAIServingBase):
generate_prompts = []
# Check if we have a chat template for multimodal embeddings
chat_template_name = getattr(
self.tokenizer_manager, "chat_template_name", None
)
if chat_template_name is not None:
convs = generate_embedding_convs(texts, images, chat_template_name)
if self.template_manager.chat_template_name is not None:
convs = generate_embedding_convs(
texts, images, self.template_manager.chat_template_name
)
for conv in convs:
generate_prompts.append(conv.get_prompt())
else:
......@@ -122,7 +128,7 @@ class OpenAIServingEmbedding(OpenAIServingBase):
adapted_request: EmbeddingReqInput,
request: EmbeddingRequest,
raw_request: Request,
) -> Union[EmbeddingResponse, ErrorResponse]:
) -> Union[EmbeddingResponse, ErrorResponse, ORJSONResponse]:
"""Handle the embedding request"""
try:
ret = await self.tokenizer_manager.generate_request(
......
......@@ -2,6 +2,7 @@ import logging
from typing import Any, Dict, List, Optional, Union
from fastapi import Request
from fastapi.responses import ORJSONResponse
from sglang.srt.entrypoints.openai.protocol import (
ErrorResponse,
......@@ -15,7 +16,10 @@ logger = logging.getLogger(__name__)
class OpenAIServingRerank(OpenAIServingBase):
"""Handler for rerank requests"""
"""Handler for /v1/rerank requests"""
# NOTE: /v1/rerank is not an official OpenAI endpoint. This handler may be moved
# to a different module in the future.
def _request_id_prefix(self) -> str:
return "rerank-"
......@@ -61,7 +65,7 @@ class OpenAIServingRerank(OpenAIServingBase):
adapted_request: EmbeddingReqInput,
request: V1RerankReqInput,
raw_request: Request,
) -> Union[RerankResponse, ErrorResponse]:
) -> Union[List[RerankResponse], ErrorResponse, ORJSONResponse]:
"""Handle the rerank request"""
try:
ret = await self.tokenizer_manager.generate_request(
......@@ -74,16 +78,16 @@ class OpenAIServingRerank(OpenAIServingBase):
if not isinstance(ret, list):
ret = [ret]
response = self._build_rerank_response(ret, request)
return response
responses = self._build_rerank_response(ret, request)
return responses
def _build_rerank_response(
self, ret: List[Dict[str, Any]], request: V1RerankReqInput
) -> List[RerankResponse]:
"""Build the rerank response from generation results"""
response = []
responses = []
for idx, ret_item in enumerate(ret):
response.append(
responses.append(
RerankResponse(
score=ret_item["embedding"],
document=request.documents[idx],
......@@ -93,6 +97,6 @@ class OpenAIServingRerank(OpenAIServingBase):
)
# Sort by score in descending order (highest relevance first)
response.sort(key=lambda x: x.score, reverse=True)
responses.sort(key=lambda x: x.score, reverse=True)
return response
return responses
import logging
from typing import Any, Dict, List, Optional, Union
from typing import Union
from fastapi import Request
......@@ -14,7 +14,10 @@ logger = logging.getLogger(__name__)
class OpenAIServingScore(OpenAIServingBase):
"""Handler for scoring requests"""
"""Handler for /v1/score requests"""
# NOTE: /v1/score is not an official OpenAI endpoint. This handler may be moved
# to a different module in the future.
def _request_id_prefix(self) -> str:
return "score-"
......
import logging
from typing import Any, Dict, List, Optional, Union
import jinja2.nodes
import transformers.utils.chat_template_utils as hf_chat_utils
from sglang.srt.entrypoints.openai.protocol import (
ChatCompletionRequest,
CompletionRequest,
......@@ -13,168 +10,6 @@ from sglang.srt.entrypoints.openai.protocol import (
logger = logging.getLogger(__name__)
# ============================================================================
# JINJA TEMPLATE CONTENT FORMAT DETECTION
# ============================================================================
#
# This adapts vLLM's approach for detecting chat template content format:
# https://github.com/vllm-project/vllm/blob/02f0c7b220422792f5e53de2a7d51d2d3ff2df28/vllm/entrypoints/chat_utils.py#L296-L313
# - Analyzes Jinja template AST to detect content iteration patterns
# - 'openai' format: templates with {%- for content in message['content'] -%} loops
# - 'string' format: templates that expect simple string content
# - Processes content accordingly to match template expectations
def _is_var_access(node: jinja2.nodes.Node, varname: str) -> bool:
"""Check if node is a variable access like {{ varname }}"""
if isinstance(node, jinja2.nodes.Name):
return node.ctx == "load" and node.name == varname
return False
def _is_attr_access(node: jinja2.nodes.Node, varname: str, key: str) -> bool:
"""Check if node is an attribute access like {{ varname['key'] }} or {{ varname.key }}"""
if isinstance(node, jinja2.nodes.Getitem):
return (
_is_var_access(node.node, varname)
and isinstance(node.arg, jinja2.nodes.Const)
and node.arg.value == key
)
if isinstance(node, jinja2.nodes.Getattr):
return _is_var_access(node.node, varname) and node.attr == key
return False
def _is_var_or_elems_access(
node: jinja2.nodes.Node,
varname: str,
key: str = None,
) -> bool:
"""Check if node accesses varname or varname[key] with filters/tests"""
if isinstance(node, jinja2.nodes.Filter):
return node.node is not None and _is_var_or_elems_access(
node.node, varname, key
)
if isinstance(node, jinja2.nodes.Test):
return _is_var_or_elems_access(node.node, varname, key)
if isinstance(node, jinja2.nodes.Getitem) and isinstance(
node.arg, jinja2.nodes.Slice
):
return _is_var_or_elems_access(node.node, varname, key)
return _is_attr_access(node, varname, key) if key else _is_var_access(node, varname)
def _try_extract_ast(chat_template: str):
"""Try to parse the Jinja template into an AST"""
try:
jinja_compiled = hf_chat_utils._compile_jinja_template(chat_template)
return jinja_compiled.environment.parse(chat_template)
except Exception as e:
logger.debug(f"Error when compiling Jinja template: {e}")
return None
def detect_template_content_format(chat_template: str) -> str:
"""
Detect whether a chat template expects 'string' or 'openai' content format.
- 'string': content is a simple string (like DeepSeek templates)
- 'openai': content is a list of structured dicts (like Llama4 templates)
Detection logic:
- If template has loops like {%- for content in message['content'] -%} → 'openai'
- Otherwise → 'string'
"""
jinja_ast = _try_extract_ast(chat_template)
if jinja_ast is None:
return "string"
try:
# Look for patterns like: {%- for content in message['content'] -%}
for loop_ast in jinja_ast.find_all(jinja2.nodes.For):
loop_iter = loop_ast.iter
# Check if iterating over message['content'] or similar
if _is_var_or_elems_access(loop_iter, "message", "content"):
return "openai" # Found content iteration → openai format
return "string" # No content loops found → string format
except Exception as e:
logger.debug(f"Error when parsing AST of Jinja template: {e}")
return "string"
def process_content_for_template_format(
msg_dict: dict,
content_format: str,
image_data: list,
audio_data: list,
modalities: list,
) -> dict:
"""
Process message content based on detected template format.
Args:
msg_dict: Message dictionary with content
content_format: 'string' or 'openai' (detected via AST analysis)
image_data: List to append extracted image URLs
audio_data: List to append extracted audio URLs
modalities: List to append modalities
Returns:
Processed message dictionary
"""
if not isinstance(msg_dict.get("content"), list):
# Already a string or None, no processing needed
return {k: v for k, v in msg_dict.items() if v is not None}
if content_format == "openai":
# OpenAI format: preserve structured content list, normalize types
processed_content_parts = []
for chunk in msg_dict["content"]:
if isinstance(chunk, dict):
chunk_type = chunk.get("type")
if chunk_type == "image_url":
image_data.append(chunk["image_url"]["url"])
if chunk.get("modalities"):
modalities.append(chunk.get("modalities"))
# Normalize to simple 'image' type for template compatibility
processed_content_parts.append({"type": "image"})
elif chunk_type == "audio_url":
audio_data.append(chunk["audio_url"]["url"])
# Normalize to simple 'audio' type
processed_content_parts.append({"type": "audio"})
else:
# Keep other content as-is (text, etc.)
processed_content_parts.append(chunk)
new_msg = {
k: v for k, v in msg_dict.items() if v is not None and k != "content"
}
new_msg["content"] = processed_content_parts
return new_msg
else: # content_format == "string"
# String format: flatten to text only (for templates like DeepSeek)
text_parts = []
for chunk in msg_dict["content"]:
if isinstance(chunk, dict) and chunk.get("type") == "text":
text_parts.append(chunk["text"])
# Note: For string format, we ignore images/audio since the template
# doesn't expect structured content - multimodal placeholders would
# need to be inserted differently
new_msg = msg_dict.copy()
new_msg["content"] = " ".join(text_parts) if text_parts else ""
new_msg = {k: v for k, v in new_msg.items() if v is not None}
return new_msg
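And the corresponding content processing, sketched with a made-up multimodal message:

msg = {
    "role": "user",
    "content": [
        {"type": "text", "text": "Describe this image."},
        {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
    ],
}
image_data, audio_data, modalities = [], [], []

# 'openai' format: keeps a structured content list and normalizes the image chunk;
# the URL is appended to image_data.
print(process_content_for_template_format(msg, "openai", image_data, audio_data, modalities))

# 'string' format: flattens to the text parts only.
print(process_content_for_template_format(msg, "string", [], [], []))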
def to_openai_style_logprobs(
input_token_logprobs=None,
output_token_logprobs=None,
......
......@@ -6,6 +6,7 @@ from typing import Any, Dict, List
from partial_json_parser.core.exceptions import MalformedJSON
from partial_json_parser.core.options import Allow
from sglang.srt.entrypoints.openai.protocol import Tool
from sglang.srt.function_call.core_types import (
StreamingParseResult,
ToolCallItem,
......@@ -16,7 +17,6 @@ from sglang.srt.function_call.utils import (
_is_complete_json,
_partial_json_loads,
)
from sglang.srt.openai_api.protocol import Tool
logger = logging.getLogger(__name__)
......
......@@ -3,6 +3,7 @@ import logging
import re
from typing import List
from sglang.srt.entrypoints.openai.protocol import Tool
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
from sglang.srt.function_call.core_types import (
StreamingParseResult,
......@@ -12,7 +13,6 @@ from sglang.srt.function_call.core_types import (
)
from sglang.srt.function_call.ebnf_composer import EBNFComposer
from sglang.srt.function_call.utils import _is_complete_json
from sglang.srt.openai_api.protocol import Tool
logger = logging.getLogger(__name__)
......
import logging
from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Type, Union
from sglang.srt.entrypoints.openai.protocol import (
StructuralTagResponseFormat,
StructuresResponseFormat,
Tool,
ToolChoice,
)
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
from sglang.srt.function_call.core_types import ToolCallItem
from sglang.srt.function_call.deepseekv3_detector import DeepSeekV3Detector
......@@ -8,12 +14,6 @@ from sglang.srt.function_call.llama32_detector import Llama32Detector
from sglang.srt.function_call.mistral_detector import MistralDetector
from sglang.srt.function_call.pythonic_detector import PythonicDetector
from sglang.srt.function_call.qwen25_detector import Qwen25Detector
from sglang.srt.openai_api.protocol import (
StructuralTagResponseFormat,
StructuresResponseFormat,
Tool,
ToolChoice,
)
logger = logging.getLogger(__name__)
......