Unverified commit ca600e8c authored by yichuan~, committed by GitHub

Add support for logprobs in OpenAI chat API (#852)

parent 0c0c8137
@@ -106,12 +106,24 @@ response = client.chat.completions.create(
         {"role": "user", "content": "List 3 countries and their capitals."},
     ],
     temperature=0.8,
-    max_tokens=64,
+    max_tokens=1,
     logprobs=True,
-    n=1,
+    top_logprobs=3,
 )
 print(response)
+
+# Chat completion
+response = client.chat.completions.create(
+    model="default",
+    messages=[
+        {"role": "system", "content": "You are a helpful AI assistant"},
+        {"role": "user", "content": "List 3 countries and their capitals."},
+    ],
+    temperature=0.8,
+    max_tokens=1,
+    n=1,
+)
+print(response)
 
 # Chat completion
 response = client.chat.completions.create(
@@ -121,8 +133,21 @@ response = client.chat.completions.create(
         {"role": "user", "content": "List 3 countries and their capitals."},
     ],
     temperature=0.8,
-    max_tokens=64,
+    max_tokens=1,
     logprobs=True,
+    top_logprobs=3,
+)
+print(response)
+
+# Chat completion
+response = client.chat.completions.create(
+    model="default",
+    messages=[
+        {"role": "system", "content": "You are a helpful AI assistant"},
+        {"role": "user", "content": "List 3 countries and their capitals."},
+    ],
+    temperature=0.8,
+    max_tokens=1,
     n=4,
 )
 print(response)
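With `logprobs=True` and `top_logprobs=3`, each returned choice carries an OpenAI-style logprobs object. A minimal sketch of reading it back, assuming the response follows the standard OpenAI chat-completions schema (`response` here is the object from the logprobs example above):

    for choice in response.choices:
        if choice.logprobs is None:
            continue
        for token_info in choice.logprobs.content:
            # Each entry carries the sampled token, its logprob, raw bytes,
            # and the requested top alternatives.
            alts = ", ".join(
                f"{alt.token!r}={alt.logprob:.3f}" for alt in token_info.top_logprobs
            )
            print(f"{token_info.token!r}: {token_info.logprob:.3f} (top: {alts})")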
@@ -43,7 +43,9 @@ from sglang.srt.openai_api.protocol import (
     ChatCompletionResponseChoice,
     ChatCompletionResponseStreamChoice,
     ChatCompletionStreamResponse,
+    ChatCompletionTokenLogprob,
     ChatMessage,
+    ChoiceLogprobs,
     CompletionRequest,
     CompletionResponse,
     CompletionResponseChoice,
@@ -54,6 +56,7 @@ from sglang.srt.openai_api.protocol import (
     FileRequest,
     FileResponse,
     LogProbs,
+    TopLogprob,
     UsageInfo,
 )
@@ -70,7 +73,7 @@ class FileMetadata:
 batch_storage: Dict[str, BatchResponse] = {}
 file_id_request: Dict[str, FileMetadata] = {}
 file_id_response: Dict[str, FileResponse] = {}
-## map file id to file path in SGlang backend
+# map file id to file path in SGlang backend
 file_id_storage: Dict[str, str] = {}
@@ -261,7 +264,7 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
             failed_requests += len(file_request_list)
 
         for idx, response in enumerate(responses):
-            ## the batch_req here can be changed to be named within a batch granularity
+            # the batch_req here can be changed to be named within a batch granularity
            response_json = {
                 "id": f"batch_req_{uuid.uuid4()}",
                 "custom_id": file_request_list[idx].get("custom_id"),
@@ -333,6 +336,8 @@ def v1_generate_request(all_requests):
     prompts = []
     sampling_params_list = []
+    return_logprobs = []
+    top_logprobs_nums = []
     first_prompt_type = type(all_requests[0].prompt)
     for request in all_requests:
         prompt = request.prompt
@@ -340,6 +345,10 @@ def v1_generate_request(all_requests):
             type(prompt) == first_prompt_type
         ), "All prompts must be of the same type in file input settings"
         prompts.append(prompt)
+        return_logprobs.append(request.logprobs is not None and request.logprobs > 0)
+        top_logprobs_nums.append(
+            request.logprobs if request.logprobs is not None else 0
+        )
         sampling_params_list.append(
             {
                 "temperature": request.temperature,
@@ -361,6 +370,8 @@ def v1_generate_request(all_requests):
     if len(all_requests) == 1:
         prompt = prompts[0]
         sampling_params_list = sampling_params_list[0]
+        return_logprobs = return_logprobs[0]
+        top_logprobs_nums = top_logprobs_nums[0]
         if isinstance(prompt, str) or isinstance(prompt[0], str):
             prompt_kwargs = {"text": prompt}
         else:
@@ -370,15 +381,11 @@ def v1_generate_request(all_requests):
             prompt_kwargs = {"text": prompts}
         else:
             prompt_kwargs = {"input_ids": prompts}
     adapted_request = GenerateReqInput(
         **prompt_kwargs,
         sampling_params=sampling_params_list,
-        return_logprob=all_requests[0].logprobs is not None
-        and all_requests[0].logprobs > 0,
-        top_logprobs_num=(
-            all_requests[0].logprobs if all_requests[0].logprobs is not None else 0
-        ),
+        return_logprob=return_logprobs,
+        top_logprobs_num=top_logprobs_nums,
         return_text_in_logprobs=True,
         stream=all_requests[0].stream,
     )
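The completions path now collects logprob settings for every request in a batch instead of reading only the first one. A standalone sketch of that collection logic, using a stand-in dataclass in place of CompletionRequest (the stand-in name `_Req` is illustrative only):

    from dataclasses import dataclass
    from typing import List, Optional

    @dataclass
    class _Req:
        # Stand-in for CompletionRequest; in the completions API, logprobs is an int.
        logprobs: Optional[int] = None

    def collect_logprob_settings(all_requests: List[_Req]):
        # Mirrors the per-request collection added to v1_generate_request above.
        return_logprobs = [r.logprobs is not None and r.logprobs > 0 for r in all_requests]
        top_logprobs_nums = [r.logprobs if r.logprobs is not None else 0 for r in all_requests]
        return return_logprobs, top_logprobs_nums

    print(collect_logprob_settings([_Req(logprobs=3), _Req(), _Req(logprobs=0)]))
    # -> ([True, False, False], [3, 0, 0])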
@@ -430,7 +437,7 @@ def v1_generate_response(request, ret, to_file=False):
         logprobs = None
 
         if to_file:
-            ## to make the choise data json serializable
+            # to make the choise data json serializable
             choice_data = {
                 "index": 0,
                 "text": text,
@@ -454,7 +461,7 @@ def v1_generate_response(request, ret, to_file=False):
                 "status_code": 200,
                 "request_id": ret[i]["meta_info"]["id"],
                 "body": {
-                    ## remain the same but if needed we can change that
+                    # remain the same but if needed we can change that
                     "id": ret[i]["meta_info"]["id"],
                     "object": "text_completion",
                     "created": int(time.time()),
@@ -590,6 +597,8 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
     texts = []
     sampling_params_list = []
     image_data_list = []
+    return_logprobs = []
+    top_logprobs_nums = []
     for request in all_requests:
         # Prep the data needed for the underlying GenerateReqInput:
         #  - prompt: The full prompt string.
@@ -620,6 +629,8 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
             stop = request.stop
             image_data = None
         texts.append(prompt)
+        return_logprobs.append(request.logprobs)
+        top_logprobs_nums.append(request.top_logprobs)
         sampling_params_list.append(
             {
                 "temperature": request.temperature,
@@ -637,11 +648,16 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
         texts = texts[0]
         sampling_params_list = sampling_params_list[0]
         image_data = image_data_list[0]
+        return_logprobs = return_logprobs[0]
+        top_logprobs_nums = top_logprobs_nums[0]
     adapted_request = GenerateReqInput(
         text=texts,
         image_data=image_data,
         sampling_params=sampling_params_list,
-        stream=request.stream,
+        return_logprob=return_logprobs,
+        top_logprobs_num=top_logprobs_nums,
+        stream=all_requests[0].stream,
+        return_text_in_logprobs=True,
     )
     if len(all_requests) == 1:
         return adapted_request, all_requests[0]
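Unlike the completions API, where logprobs is an integer count, the chat API uses a boolean logprobs plus an integer top_logprobs, and both are forwarded unchanged. An illustrative sketch of the collected values for a hypothetical two-request batch, and the single-request flattening that follows:

    # Hypothetical batch: the first chat request asks for logprobs, the second does not.
    return_logprobs = [True, False]   # from request.logprobs (a bool in the chat API)
    top_logprobs_nums = [3, None]     # from request.top_logprobs

    # With a single request, the lists collapse to scalars before being passed to
    # GenerateReqInput(return_logprob=..., top_logprobs_num=...).
    if len(return_logprobs) == 1:
        return_logprobs = return_logprobs[0]
        top_logprobs_nums = top_logprobs_nums[0]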
@@ -654,26 +670,63 @@ def v1_chat_generate_response(request, ret, to_file=False):
     total_completion_tokens = 0
 
     for idx, ret_item in enumerate(ret):
+        logprobs = False
+        if isinstance(request, List) and request[idx].logprobs:
+            logprobs = True
+        elif (not isinstance(request, List)) and request.logprobs:
+            logprobs = True
+        if logprobs:
+            logprobs = to_openai_style_logprobs(
+                output_token_logprobs=ret_item["meta_info"]["output_token_logprobs"],
+                output_top_logprobs=ret_item["meta_info"]["output_top_logprobs"],
+            )
+            token_logprobs = []
+            for token, logprob in zip(logprobs.tokens, logprobs.token_logprobs):
+                token_bytes = list(token.encode("utf-8"))
+                top_logprobs = []
+                if logprobs.top_logprobs:
+                    for top_token, top_logprob in logprobs.top_logprobs[0].items():
+                        top_token_bytes = list(top_token.encode("utf-8"))
+                        top_logprobs.append(
+                            TopLogprob(
+                                token=top_token,
+                                bytes=top_token_bytes,
+                                logprob=top_logprob,
+                            )
+                        )
+                token_logprobs.append(
+                    ChatCompletionTokenLogprob(
+                        token=token,
+                        bytes=token_bytes,
+                        logprob=logprob,
+                        top_logprobs=top_logprobs,
+                    )
+                )
+            choice_logprobs = ChoiceLogprobs(content=token_logprobs)
+        else:
+            choice_logprobs = None
+
         prompt_tokens = ret_item["meta_info"]["prompt_tokens"]
         completion_tokens = ret_item["meta_info"]["completion_tokens"]
 
         if to_file:
-            ## to make the choice data json serializable
+            # to make the choice data json serializable
             choice_data = {
                 "index": 0,
                 "message": {"role": "assistant", "content": ret_item["text"]},
-                "logprobs": None,
+                "logprobs": choice_logprobs,
                 "finish_reason": ret_item["meta_info"]["finish_reason"],
             }
         else:
             choice_data = ChatCompletionResponseChoice(
                 index=idx,
                 message=ChatMessage(role="assistant", content=ret_item["text"]),
+                logprobs=choice_logprobs,
                 finish_reason=ret_item["meta_info"]["finish_reason"],
             )
 
         choices.append(choice_data)
-        total_prompt_tokens = prompt_tokens
+        total_prompt_tokens += prompt_tokens
         total_completion_tokens += completion_tokens
 
     if to_file:
         responses = []
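The block above converts SGLang's raw per-token logprobs into the OpenAI-style objects introduced in the protocol changes further down. A minimal sketch of what one converted token entry looks like, built here from made-up values:

    from sglang.srt.openai_api.protocol import (
        ChatCompletionTokenLogprob,
        ChoiceLogprobs,
        TopLogprob,
    )

    # Made-up logprob values for a single generated token.
    entry = ChatCompletionTokenLogprob(
        token="Paris",
        bytes=list("Paris".encode("utf-8")),
        logprob=-0.12,
        top_logprobs=[
            TopLogprob(token="Paris", bytes=list("Paris".encode("utf-8")), logprob=-0.12),
            TopLogprob(token="Lyon", bytes=list("Lyon".encode("utf-8")), logprob=-2.35),
        ],
    )
    choice_logprobs = ChoiceLogprobs(content=[entry])
    print(choice_logprobs)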
@@ -683,7 +736,7 @@ def v1_chat_generate_response(request, ret, to_file=False):
                 "status_code": 200,
                 "request_id": ret[i]["meta_info"]["id"],
                 "body": {
-                    ## remain the same but if needed we can change that
+                    # remain the same but if needed we can change that
                     "id": ret[i]["meta_info"]["id"],
                     "object": "chat.completion",
                     "created": int(time.time()),
...
@@ -54,6 +54,24 @@ class LogProbs(BaseModel):
     top_logprobs: List[Optional[Dict[str, float]]] = Field(default_factory=list)
 
 
+class TopLogprob(BaseModel):
+    token: str
+    bytes: List[int]
+    logprob: float
+
+
+class ChatCompletionTokenLogprob(BaseModel):
+    token: str
+    bytes: List[int]
+    logprob: float
+    top_logprobs: List[TopLogprob]
+
+
+class ChoiceLogprobs(BaseModel):
+    # build for v1/chat/completions response
+    content: List[ChatCompletionTokenLogprob]
+
+
 class UsageInfo(BaseModel):
     prompt_tokens: int = 0
     total_tokens: int = 0
@@ -239,8 +257,8 @@ class ChatMessage(BaseModel):
 class ChatCompletionResponseChoice(BaseModel):
     index: int
     message: ChatMessage
-    logprobs: Optional[LogProbs] = None
-    finish_reason: Optional[str] = None
+    logprobs: Optional[Union[LogProbs, ChoiceLogprobs]] = None
+    finish_reason: str
 
 
 class ChatCompletionResponse(BaseModel):
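Because logprobs is now a Union, a response choice can carry either the legacy LogProbs shape or the new ChoiceLogprobs, and finish_reason becomes a required string. A hedged construction example with illustrative values:

    from sglang.srt.openai_api.protocol import (
        ChatCompletionResponseChoice,
        ChatCompletionTokenLogprob,
        ChatMessage,
        ChoiceLogprobs,
    )

    choice = ChatCompletionResponseChoice(
        index=0,
        message=ChatMessage(role="assistant", content="The capital of France is Paris."),
        logprobs=ChoiceLogprobs(
            content=[
                ChatCompletionTokenLogprob(
                    token="The", bytes=[84, 104, 101], logprob=-0.05, top_logprobs=[]
                )
            ]
        ),
        finish_reason="stop",
    )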
...