Unverified commit 3fc97f67, authored by Lianmin Zheng and committed by GitHub

Move openai api server into a separate file (#429)

parent abc548c7
File: sglang/srt/managers/tokenizer_manager.py

 import asyncio
 import concurrent.futures
 import dataclasses
+import logging
 import multiprocessing as mp
 import os
 from typing import List
@@ -31,6 +32,8 @@ from sglang.srt.utils import get_exception_traceback, is_multimodal_model, load_

 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

+logger = logging.getLogger(__name__)
+

 @dataclasses.dataclass
 class ReqState:
@@ -185,10 +188,15 @@ class TokenizerManager:
             while True:
                 await event.wait()
-                yield self.convert_logprob_style(state.out_list[-1],
-                                                 obj.return_logprob,
-                                                 obj.top_logprobs_num,
-                                                 obj.return_text_in_logprobs)
+                out = self.convert_logprob_style(state.out_list[-1],
+                                                 obj.return_logprob,
+                                                 obj.top_logprobs_num,
+                                                 obj.return_text_in_logprobs)
+
+                if self.server_args.log_requests and state.finished:
+                    logger.info(f"in={obj.text}, out={out}")
+
+                yield out
+
                 state.out_list = []
                 if state.finished:
                     del self.rid_to_state[rid]
...

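The per-request logging added here is a plain conditional logger.info gated on the new log_requests flag and on the request being finished. A standalone sketch of the same pattern, using stand-in objects rather than the real ServerArgs and request classes:

import logging
from types import SimpleNamespace

logging.basicConfig(level=logging.INFO, format="%(message)s")
logger = logging.getLogger(__name__)

# Stand-ins for server_args and the tokenized request object (obj); not the real classes.
server_args = SimpleNamespace(log_requests=True)
obj = SimpleNamespace(text="What is the capital of France?")

def maybe_log(out, finished):
    # Mirrors the diff: log the input/output pair once, on the final chunk only.
    if server_args.log_requests and finished:
        logger.info(f"in={obj.text}, out={out}")

maybe_log({"text": " Paris", "meta_info": {}}, finished=True)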
"""Conversion between OpenAI APIs and native SRT APIs"""
import json
import os
from fastapi import HTTPException, Request
from fastapi.responses import StreamingResponse
from sglang.srt.conversation import (
Conversation,
SeparatorStyle,
chat_template_exists,
generate_chat_conv,
register_conv_template,
)
from sglang.srt.managers.io_struct import GenerateReqInput
from sglang.srt.openai_protocol import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseChoice,
ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse,
ChatMessage,
CompletionRequest,
CompletionResponse,
CompletionResponseChoice,
CompletionResponseStreamChoice,
CompletionStreamResponse,
DeltaMessage,
LogProbs,
UsageInfo,
)
from sglang.srt.utils import jsonify_pydantic_model
chat_template_name = None
def load_chat_template_for_openai_api(chat_template_arg):
global chat_template_name
print(f"Use chat template: {chat_template_arg}")
if not chat_template_exists(chat_template_arg):
if not os.path.exists(chat_template_arg):
raise RuntimeError(
f"Chat template {chat_template_arg} is not a built-in template name "
"or a valid chat template file path."
)
with open(chat_template_arg, "r") as filep:
template = json.load(filep)
try:
sep_style = SeparatorStyle[template["sep_style"]]
except KeyError:
raise ValueError(
f"Unknown separator style: {template['sep_style']}"
) from None
register_conv_template(
Conversation(
name=template["name"],
system_template=template["system"] + "\n{system_message}",
system_message=template.get("system_message", ""),
roles=(template["user"], template["assistant"]),
sep_style=sep_style,
sep=template.get("sep", "\n"),
stop_str=template["stop_str"],
),
override=True,
)
chat_template_name = template["name"]
else:
chat_template_name = chat_template_arg
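For reference, a template file for this loader is just a JSON object with the keys read above. The sketch below writes a hypothetical template and shows how it would be loaded; the file name, the template values, and the assumption that ADD_COLON_SINGLE is a valid SeparatorStyle member are illustrative, not taken from this commit.

import json

# Hypothetical custom template; every key is one the loader reads.
my_template = {
    "name": "my-llama-chat",
    "system": "<<SYS>>",                 # becomes the system_template prefix
    "system_message": "You are a helpful assistant.",
    "user": "[INST]",                    # role markers -> Conversation.roles
    "assistant": "[/INST]",
    "sep_style": "ADD_COLON_SINGLE",     # assumed to be a SeparatorStyle member
    "sep": "\n",
    "stop_str": "</s>",
}

with open("my_template.json", "w") as f:
    json.dump(my_template, f)

# load_chat_template_for_openai_api("my_template.json") would then register the
# conversation template and set chat_template_name to "my-llama-chat".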

async def v1_completions(tokenizer_manager, raw_request: Request):
    request_json = await raw_request.json()
    request = CompletionRequest(**request_json)

    # TODO: Validate the request and return HTTPStatus.BAD_REQUEST if invalid.
    assert request.n == 1

    adapted_request = GenerateReqInput(
        text=request.prompt,
        sampling_params={
            "temperature": request.temperature,
            "max_new_tokens": request.max_tokens,
            "stop": request.stop,
            "top_p": request.top_p,
            "presence_penalty": request.presence_penalty,
            "frequency_penalty": request.frequency_penalty,
            "regex": request.regex,
        },
        return_logprob=request.logprobs is not None and request.logprobs > 0,
        top_logprobs_num=request.logprobs if request.logprobs is not None else 0,
        return_text_in_logprobs=True,
        stream=request.stream,
    )
    adapted_request.post_init()

    if adapted_request.stream:

        async def generate_stream_resp():
            stream_buffer = ""
            n_prev_token = 0
            async for content in tokenizer_manager.generate_request(adapted_request):
                text = content["text"]
                prompt_tokens = content["meta_info"]["prompt_tokens"]
                completion_tokens = content["meta_info"]["completion_tokens"]

                if not stream_buffer:  # The first chunk
                    if request.echo:
                        # Prepend prompt in response text.
                        text = request.prompt + text

                if request.logprobs:
                    # The first chunk and echo is enabled.
                    if not stream_buffer and request.echo:
                        prefill_token_logprobs = content["meta_info"][
                            "prefill_token_logprobs"
                        ]
                        prefill_top_logprobs = content["meta_info"][
                            "prefill_top_logprobs"
                        ]
                    else:
                        prefill_token_logprobs = None
                        prefill_top_logprobs = None

                    logprobs = to_openai_style_logprobs(
                        prefill_token_logprobs=prefill_token_logprobs,
                        prefill_top_logprobs=prefill_top_logprobs,
                        decode_token_logprobs=content["meta_info"][
                            "decode_token_logprobs"
                        ][n_prev_token:],
                        decode_top_logprobs=content["meta_info"]["decode_top_logprobs"][
                            n_prev_token:
                        ],
                    )
                    n_prev_token = len(content["meta_info"]["decode_token_logprobs"])
                else:
                    logprobs = None

                delta = text[len(stream_buffer) :]
                stream_buffer = content["text"]
                choice_data = CompletionResponseStreamChoice(
                    index=0,
                    text=delta,
                    logprobs=logprobs,
                    finish_reason=None,
                )
                chunk = CompletionStreamResponse(
                    id=content["meta_info"]["id"],
                    object="text_completion",
                    choices=[choice_data],
                    model=request.model,
                    usage=UsageInfo(
                        prompt_tokens=prompt_tokens,
                        completion_tokens=completion_tokens,
                        total_tokens=prompt_tokens + completion_tokens,
                    ),
                )
                yield f"data: {jsonify_pydantic_model(chunk)}\n\n"
            yield "data: [DONE]\n\n"

        return StreamingResponse(generate_stream_resp(), media_type="text/event-stream")

    # Non-streaming response.
    ret = await tokenizer_manager.generate_request(adapted_request).__anext__()
    ret = ret[0] if isinstance(ret, list) else ret

    prompt_tokens = ret["meta_info"]["prompt_tokens"]
    completion_tokens = ret["meta_info"]["completion_tokens"]
    text = ret["text"]
    if request.echo:
        text = request.prompt + text

    if request.logprobs:
        if request.echo:
            prefill_token_logprobs = ret["meta_info"]["prefill_token_logprobs"]
            prefill_top_logprobs = ret["meta_info"]["prefill_top_logprobs"]
        else:
            prefill_token_logprobs = None
            prefill_top_logprobs = None
        logprobs = to_openai_style_logprobs(
            prefill_token_logprobs=prefill_token_logprobs,
            prefill_top_logprobs=prefill_top_logprobs,
            decode_token_logprobs=ret["meta_info"]["decode_token_logprobs"],
            decode_top_logprobs=ret["meta_info"]["decode_top_logprobs"],
        )
    else:
        logprobs = None

    choice_data = CompletionResponseChoice(
        index=0,
        text=text,
        logprobs=logprobs,
        finish_reason=None,  # TODO(comaniac): Add finish reason.
    )
    response = CompletionResponse(
        id=ret["meta_info"]["id"],
        model=request.model,
        choices=[choice_data],
        usage=UsageInfo(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
        ),
    )
    return response

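Once the server is running, this function is what answers POST /v1/completions. A minimal client sketch using requests; the port (30000 is sglang's usual default) and the model name are assumptions, not part of this diff.

import requests

resp = requests.post(
    "http://localhost:30000/v1/completions",   # host/port assumed
    json={
        "model": "default",                    # model name assumed
        "prompt": "The capital of France is",
        "max_tokens": 16,
        "temperature": 0,
        "logprobs": 3,   # request top-3 logprobs per generated token
        "echo": True,    # prepend the prompt to the returned text
    },
)
data = resp.json()
print(data["choices"][0]["text"])
print(data["usage"])   # prompt_tokens / completion_tokens / total_tokens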

async def v1_chat_completions(tokenizer_manager, raw_request: Request):
    request_json = await raw_request.json()
    request = ChatCompletionRequest(**request_json)

    # TODO: Validate the request and return HTTPStatus.BAD_REQUEST if invalid.
    assert request.n == 1

    # Prep the data needed for the underlying GenerateReqInput:
    #  - prompt: The full prompt string.
    #  - stop: Custom stop tokens.
    #  - image_data: None or a list of image strings (URLs or base64 strings).
    #    None skips any image processing in GenerateReqInput.
    if not isinstance(request.messages, str):
        # Apply chat template and its stop strings.
        if chat_template_name is None:
            prompt = tokenizer_manager.tokenizer.apply_chat_template(
                request.messages, tokenize=False, add_generation_prompt=True
            )
            stop = request.stop
            image_data = None
        else:
            conv = generate_chat_conv(request, chat_template_name)
            prompt = conv.get_prompt()
            image_data = conv.image_data
            stop = conv.stop_str or []
            if request.stop:
                if isinstance(request.stop, str):
                    stop.append(request.stop)
                else:
                    stop.extend(request.stop)
    else:
        # Use the raw prompt and stop strings if the messages is already a string.
        prompt = request.messages
        stop = request.stop
        image_data = None

    adapted_request = GenerateReqInput(
        text=prompt,
        image_data=image_data,
        sampling_params={
            "temperature": request.temperature,
            "max_new_tokens": request.max_tokens,
            "stop": stop,
            "top_p": request.top_p,
            "presence_penalty": request.presence_penalty,
            "frequency_penalty": request.frequency_penalty,
            "regex": request.regex,
        },
        stream=request.stream,
    )
    adapted_request.post_init()

    if adapted_request.stream:

        async def generate_stream_resp():
            is_first = True

            stream_buffer = ""
            async for content in tokenizer_manager.generate_request(adapted_request):
                if is_first:
                    # First chunk with role
                    is_first = False
                    choice_data = ChatCompletionResponseStreamChoice(
                        index=0,
                        delta=DeltaMessage(role="assistant"),
                        finish_reason=None,
                    )
                    chunk = ChatCompletionStreamResponse(
                        id=content["meta_info"]["id"],
                        choices=[choice_data],
                        model=request.model,
                    )
                    yield f"data: {jsonify_pydantic_model(chunk)}\n\n"

                text = content["text"]
                delta = text[len(stream_buffer) :]
                stream_buffer = text
                choice_data = ChatCompletionResponseStreamChoice(
                    index=0, delta=DeltaMessage(content=delta), finish_reason=None
                )
                chunk = ChatCompletionStreamResponse(
                    id=content["meta_info"]["id"],
                    choices=[choice_data],
                    model=request.model,
                )
                yield f"data: {jsonify_pydantic_model(chunk)}\n\n"
            yield "data: [DONE]\n\n"

        return StreamingResponse(generate_stream_resp(), media_type="text/event-stream")

    # Non-streaming response.
    ret = await tokenizer_manager.generate_request(adapted_request).__anext__()
    prompt_tokens = ret["meta_info"]["prompt_tokens"]
    completion_tokens = ret["meta_info"]["completion_tokens"]
    choice_data = ChatCompletionResponseChoice(
        index=0,
        message=ChatMessage(role="assistant", content=ret["text"]),
        finish_reason=None,  # TODO(comaniac): Add finish reason.
    )
    response = ChatCompletionResponse(
        id=ret["meta_info"]["id"],
        model=request.model,
        choices=[choice_data],
        usage=UsageInfo(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
        ),
    )
    return response

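When stream=True, the handler returns a text/event-stream in which each chunk is a "data: " line holding a serialized ChatCompletionStreamResponse, terminated by "data: [DONE]". A client-side sketch that consumes it, again assuming the server listens on localhost:30000 and accepts a model named "default":

import json

import requests

with requests.post(
    "http://localhost:30000/v1/chat/completions",   # host/port assumed
    json={
        "model": "default",
        "messages": [{"role": "user", "content": "Say hello in one word."}],
        "stream": True,
    },
    stream=True,
) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data: "):
            continue
        payload = line[len("data: "):]
        if payload == "[DONE]":
            break
        chunk = json.loads(payload)
        delta = chunk["choices"][0]["delta"]
        # The first chunk carries only the role; later chunks carry content deltas.
        print(delta.get("content", ""), end="", flush=True)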

def to_openai_style_logprobs(
    prefill_token_logprobs=None,
    decode_token_logprobs=None,
    prefill_top_logprobs=None,
    decode_top_logprobs=None,
):
    ret_logprobs = LogProbs()

    def append_token_logprobs(token_logprobs):
        for logprob, _, token_text in token_logprobs:
            ret_logprobs.tokens.append(token_text)
            ret_logprobs.token_logprobs.append(logprob)

            # Not Supported yet
            ret_logprobs.text_offset.append(-1)

    def append_top_logprobs(top_logprobs):
        for tokens in top_logprobs:
            if tokens is not None:
                ret_logprobs.top_logprobs.append(
                    {token[2]: token[0] for token in tokens}
                )
            else:
                ret_logprobs.top_logprobs.append(None)

    if prefill_token_logprobs is not None:
        append_token_logprobs(prefill_token_logprobs)
    if decode_token_logprobs is not None:
        append_token_logprobs(decode_token_logprobs)
    if prefill_top_logprobs is not None:
        append_top_logprobs(prefill_top_logprobs)
    if decode_top_logprobs is not None:
        append_top_logprobs(decode_top_logprobs)

    return ret_logprobs
\ No newline at end of file

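The tuple unpacking above implies each token entry is (logprob, token_id, token_text), and each top-logprobs entry is either None or a list of such tuples. A small illustration with made-up numbers, assuming the LogProbs fields default to empty lists as the append calls require:

from sglang.srt.openai_api_adapter import to_openai_style_logprobs

# Made-up values in the (logprob, token_id, token_text) layout.
decode_token_logprobs = [(-0.11, 1024, "Hello"), (-0.45, 2048, "!")]
decode_top_logprobs = [
    [(-0.11, 1024, "Hello"), (-2.30, 77, "Hi")],
    None,  # no top-logprobs recorded for this position
]

lp = to_openai_style_logprobs(
    decode_token_logprobs=decode_token_logprobs,
    decode_top_logprobs=decode_top_logprobs,
)
print(lp.tokens)          # ["Hello", "!"]
print(lp.token_logprobs)  # [-0.11, -0.45]
print(lp.top_logprobs)    # [{"Hello": -0.11, "Hi": -2.3}, None]
print(lp.text_offset)     # [-1, -1] (offsets are not tracked yet)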

File: sglang/srt/openai_protocol.py

@@ -20,21 +20,24 @@ class UsageInfo(BaseModel):

 class CompletionRequest(BaseModel):
+    # Ordered by official OpenAI API documentation
+    # https://platform.openai.com/docs/api-reference/completions/create
     model: str
-    prompt: Union[str, List[str]]
-    suffix: Optional[str] = None
-    max_tokens: Optional[int] = 16
-    temperature: Optional[float] = 0.7
-    top_p: Optional[float] = 1.0
-    n: Optional[int] = 1
-    stream: Optional[bool] = False
-    logprobs: Optional[int] = None
-    echo: Optional[bool] = False
-    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
-    presence_penalty: Optional[float] = 0.0
-    frequency_penalty: Optional[float] = 0.0
-    best_of: Optional[int] = None
-    logit_bias: Optional[Dict[str, float]] = None
+    prompt: Union[List[int], List[List[int]], str, List[str]]
+    best_of: Optional[int] = None
+    echo: Optional[bool] = False
+    frequency_penalty: Optional[float] = 0.0
+    logit_bias: Optional[Dict[str, float]] = None
+    logprobs: Optional[int] = None
+    max_tokens: Optional[int] = 16
+    n: int = 1
+    presence_penalty: Optional[float] = 0.0
+    seed: Optional[int] = None
+    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
+    stream: Optional[bool] = False
+    suffix: Optional[str] = None
+    temperature: Optional[float] = 1.0
+    top_p: Optional[float] = 1.0
     user: Optional[str] = None

     # Extra parameters for SRT backend only and will be ignored by OpenAI models.

@@ -108,20 +111,30 @@ ChatCompletionMessageParam = Union[
 ]


+class ResponseFormat(BaseModel):
+    # type must be "json_object" or "text"
+    type: Literal["text", "json_object"]
+
+
 class ChatCompletionRequest(BaseModel):
+    # Ordered by official OpenAI API documentation
+    # https://platform.openai.com/docs/api-reference/chat/create
+    messages: List[ChatCompletionMessageParam]
     model: str
-    messages: Union[str, List[ChatCompletionMessageParam]]
-    temperature: Optional[float] = 0.7
-    top_p: Optional[float] = 1.0
-    n: Optional[int] = 1
-    max_tokens: Optional[int] = 16
+    frequency_penalty: Optional[float] = 0.0
+    logit_bias: Optional[Dict[str, float]] = None
+    logprobs: Optional[bool] = False
+    top_logprobs: Optional[int] = None
+    max_tokens: Optional[int] = None
+    n: Optional[int] = 1
+    presence_penalty: Optional[float] = 0.0
+    response_format: Optional[ResponseFormat] = None
+    seed: Optional[int] = None
     stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
     stream: Optional[bool] = False
-    presence_penalty: Optional[float] = 0.0
-    frequency_penalty: Optional[float] = 0.0
-    logit_bias: Optional[Dict[str, float]] = None
+    temperature: Optional[float] = 0.7
+    top_p: Optional[float] = 1.0
     user: Optional[str] = None
-    best_of: Optional[int] = None

     # Extra parameters for SRT backend only and will be ignored by OpenAI models.
     regex: Optional[str] = None

@@ -135,6 +148,7 @@ class ChatMessage(BaseModel):
 class ChatCompletionResponseChoice(BaseModel):
     index: int
     message: ChatMessage
+    logprobs: Optional[LogProbs] = None
     finish_reason: Optional[str] = None

@@ -155,6 +169,7 @@ class DeltaMessage(BaseModel):
 class ChatCompletionResponseStreamChoice(BaseModel):
     index: int
     delta: DeltaMessage
+    logprobs: Optional[LogProbs] = None
     finish_reason: Optional[str] = None

@@ -163,4 +178,4 @@ class ChatCompletionStreamResponse(BaseModel):
     object: str = "chat.completion.chunk"
     created: int = Field(default_factory=lambda: int(time.time()))
     model: str
     choices: List[ChatCompletionResponseStreamChoice]
\ No newline at end of file

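A quick sanity check of the reworked models, constructing each with its required fields plus a couple of the new ones. Note that seed and response_format are only schema fields at this point; the adapter above does not forward them to sampling. This sketch assumes pydantic coerces the plain dicts into the message and response-format models:

from sglang.srt.openai_protocol import ChatCompletionRequest, CompletionRequest

comp = CompletionRequest(model="default", prompt="Hello", seed=42)
print(comp.temperature, comp.n)   # 1.0 1  -- the new defaults

chat = ChatCompletionRequest(
    model="default",
    messages=[{"role": "user", "content": "Hi"}],
    response_format={"type": "json_object"},
)
print(chat.max_tokens)   # None -- no longer defaults to 16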

File: sglang/srt/server.py

@@ -3,6 +3,7 @@
 import asyncio
 import dataclasses
 import json
+import logging
 import multiprocessing as mp
 import os
 import sys
@@ -18,45 +19,23 @@ import psutil
 import requests
 import uvicorn
 import uvloop
-from fastapi import FastAPI, HTTPException, Request
+from fastapi import FastAPI, Request
 from fastapi.responses import Response, StreamingResponse

 from sglang.backend.runtime_endpoint import RuntimeEndpoint
 from sglang.srt.constrained import disable_cache
-from sglang.srt.conversation import (
-    Conversation,
-    SeparatorStyle,
-    chat_template_exists,
-    generate_chat_conv,
-    register_conv_template,
-)
 from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
-from sglang.srt.managers.io_struct import DetokenizeReqInput, GenerateReqInput
-from sglang.srt.openai_protocol import (
-    ChatCompletionRequest,
-    ChatCompletionResponse,
-    ChatCompletionResponseChoice,
-    ChatCompletionResponseStreamChoice,
-    ChatCompletionStreamResponse,
-    ChatMessage,
-    CompletionRequest,
-    CompletionResponse,
-    CompletionResponseChoice,
-    CompletionResponseStreamChoice,
-    CompletionStreamResponse,
-    DeltaMessage,
-    LogProbs,
-    UsageInfo,
-)
+from sglang.srt.managers.io_struct import GenerateReqInput
 from sglang.srt.managers.router.manager import start_router_process
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
+from sglang.srt.openai_api_adapter import (
+    v1_completions, v1_chat_completions, load_chat_template_for_openai_api)
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import (
     allocate_init_ports,
     assert_pkg_version,
     enable_show_time_cost,
-    jsonify_pydantic_model,
     get_exception_traceback,
     API_KEY_HEADER_NAME,
     APIKeyValidatorMiddleware
@@ -67,7 +46,6 @@ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

 app = FastAPI()
 tokenizer_manager = None
-chat_template_name = None


 @app.get("/health")
@@ -117,343 +95,23 @@ async def generate_request(obj: GenerateReqInput):

 @app.post("/v1/completions")
-async def v1_completions(raw_request: Request):
-    request_json = await raw_request.json()
+async def openai_v1_completions(raw_request: Request):
+    return await v1_completions(tokenizer_manager, raw_request)

    ... (removed here: the old inline v1_completions body, roughly 120 lines that now live
    nearly verbatim in v1_completions of sglang/srt/openai_api_adapter.py; the old version
    called the module-level generate_request and an async make_openai_style_logprobs helper) ...
@app.post("/v1/chat/completions") @app.post("/v1/chat/completions")
async def v1_chat_completions(raw_request: Request): async def openai_v1_chat_completions(raw_request: Request):
request_json = await raw_request.json() return await v1_chat_completions(tokenizer_manager, raw_request)

    ... (removed here: the old inline v1_chat_completions body, including its HTTPException(503)
    guard against structured message content when falling back to HuggingFace chat templates,
    plus the make_openai_style_logprobs and load_chat_template_for_openai_api helpers; roughly
    170 lines, all of which now live in sglang/srt/openai_api_adapter.py) ...

 def launch_server(server_args: ServerArgs, pipe_finish_writer):
     global tokenizer_manager

+    logging.basicConfig(
+        level=getattr(logging, server_args.log_level.upper()),
+        format="%(message)s",
+    )
+
     # Set global environments
     os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
     if server_args.show_time_cost:
@@ -656,4 +314,4 @@ class Runtime:
             pos += len(cur)

     def __del__(self):
         self.shutdown()
\ No newline at end of file

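Because only the route handlers moved, any OpenAI-compatible client keeps working against /v1/completions and /v1/chat/completions. A sketch with the official openai Python client (v1 interface); the base_url, api_key, and model name are assumptions:

from openai import OpenAI

client = OpenAI(base_url="http://localhost:30000/v1", api_key="EMPTY")  # values assumed

response = client.chat.completions.create(
    model="default",
    messages=[{"role": "user", "content": "List three French cheeses."}],
    temperature=0.7,
)
print(response.choices[0].message.content)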
...@@ -34,6 +34,7 @@ class ServerArgs: ...@@ -34,6 +34,7 @@ class ServerArgs:
# Logging # Logging
log_level: str = "info" log_level: str = "info"
log_requests: bool = False
disable_log_stats: bool = False disable_log_stats: bool = False
log_stats_interval: int = 10 log_stats_interval: int = 10
show_time_cost: bool = False show_time_cost: bool = False
...@@ -180,6 +181,11 @@ class ServerArgs: ...@@ -180,6 +181,11 @@ class ServerArgs:
default=ServerArgs.log_level, default=ServerArgs.log_level,
help="Logging level", help="Logging level",
) )
parser.add_argument(
"--log-requests",
action="store_true",
help="Log all requests",
)
parser.add_argument( parser.add_argument(
"--disable-log-stats", "--disable-log-stats",
action="store_true", action="store_true",
......
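The new flag is off by default; when enabled, the TokenizerManager change at the top of this diff logs each finished request's input and output at INFO level. On the command line this would presumably be passed as --log-requests alongside the usual launch flags; programmatically it is just the dataclass field, as in the sketch below (which assumes model_path is the only required ServerArgs field):

from sglang.srt.server_args import ServerArgs

# Hypothetical construction; only the fields relevant here are set.
args = ServerArgs(model_path="meta-llama/Llama-2-7b-chat-hf", log_requests=True)
print(args.log_requests)  # True
print(args.log_level)     # "info"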