sglang commit ca600e8c (Unverified)
Add support for logprobs in OpenAI chat API (#852)

Authored Aug 01, 2024 by yichuan~; committed by GitHub on Aug 01, 2024
Parent commit: 0c0c8137
Showing 3 changed files with 116 additions and 20 deletions (+116 / -20):
  examples/usage/openai_parallel_sample.py   +28  -3
  python/sglang/srt/openai_api/adapter.py    +68  -15
  python/sglang/srt/openai_api/protocol.py   +20  -2
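For orientation before the file diffs: a rough client-side sketch of the feature this commit adds. It assumes an SGLang OpenAI-compatible server is already running locally and that the served model is registered as "default", as in the repository's example scripts; the base URL, API key, and generation settings below are illustrative, not part of the commit.

```python
# Sketch only: base_url/api_key are placeholders for a locally running SGLang server.
import openai

client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="default",
    messages=[
        {"role": "system", "content": "You are a helpful AI assistant"},
        {"role": "user", "content": "List 3 countries and their capitals."},
    ],
    temperature=0.8,
    max_tokens=32,
    logprobs=True,    # ask for per-token log probabilities
    top_logprobs=3,   # and the top-3 alternatives at each position
)

# With this commit, each chat choice carries OpenAI-style logprobs content.
for entry in response.choices[0].logprobs.content:
    print(entry.token, entry.logprob, [alt.token for alt in entry.top_logprobs])
```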
examples/usage/openai_parallel_sample.py
@@ -106,12 +106,24 @@ response = client.chat.completions.create(
         {"role": "user", "content": "List 3 countries and their capitals."},
     ],
     temperature=0.8,
-    max_tokens=64,
+    max_tokens=1,
+    logprobs=True,
+    n=1,
+    top_logprobs=3,
 )
 print(response)
 
+# Chat completion
+response = client.chat.completions.create(
+    model="default",
+    messages=[
+        {"role": "system", "content": "You are a helpful AI assistant"},
+        {"role": "user", "content": "List 3 countries and their capitals."},
+    ],
+    temperature=0.8,
+    max_tokens=1,
+    n=1,
+)
+print(response)
+
 # Chat completion
 response = client.chat.completions.create(
@@ -121,8 +133,21 @@ response = client.chat.completions.create(
         {"role": "user", "content": "List 3 countries and their capitals."},
     ],
     temperature=0.8,
-    max_tokens=64,
+    max_tokens=1,
+    logprobs=True,
+    top_logprobs=3,
 )
 print(response)
 
+# Chat completion
+response = client.chat.completions.create(
+    model="default",
+    messages=[
+        {"role": "system", "content": "You are a helpful AI assistant"},
+        {"role": "user", "content": "List 3 countries and their capitals."},
+    ],
+    temperature=0.8,
+    max_tokens=1,
+    n=4,
+)
+print(response)
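The example file also combines logprobs with parallel sampling (n > 1). A small sketch of consuming such a response, under the same assumptions as the sketch above; every returned choice carries its own logprobs block.

```python
# Sketch: n > 1 returns one choice per sample, each with its own logprobs.
response = client.chat.completions.create(
    model="default",
    messages=[{"role": "user", "content": "List 3 countries and their capitals."}],
    temperature=0.8,
    max_tokens=16,
    n=2,
    logprobs=True,
    top_logprobs=3,
)

for choice in response.choices:
    first = choice.logprobs.content[0]
    print(choice.index, repr(first.token), round(first.logprob, 3))
```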
python/sglang/srt/openai_api/adapter.py
@@ -43,7 +43,9 @@ from sglang.srt.openai_api.protocol import (
     ChatCompletionResponseChoice,
     ChatCompletionResponseStreamChoice,
     ChatCompletionStreamResponse,
+    ChatCompletionTokenLogprob,
     ChatMessage,
+    ChoiceLogprobs,
     CompletionRequest,
     CompletionResponse,
     CompletionResponseChoice,
@@ -54,6 +56,7 @@ from sglang.srt.openai_api.protocol import (
     FileRequest,
     FileResponse,
     LogProbs,
+    TopLogprob,
     UsageInfo,
 )
@@ -70,7 +73,7 @@ class FileMetadata:
 batch_storage: Dict[str, BatchResponse] = {}
 file_id_request: Dict[str, FileMetadata] = {}
 file_id_response: Dict[str, FileResponse] = {}
 
 # map file id to file path in SGlang backend
 file_id_storage: Dict[str, str] = {}
@@ -261,7 +264,7 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
             failed_requests += len(file_request_list)
 
         for idx, response in enumerate(responses):
             # the batch_req here can be changed to be named within a batch granularity
             response_json = {
                 "id": f"batch_req_{uuid.uuid4()}",
                 "custom_id": file_request_list[idx].get("custom_id"),
@@ -333,6 +336,8 @@ def v1_generate_request(all_requests):
     prompts = []
     sampling_params_list = []
+    return_logprobs = []
+    top_logprobs_nums = []
     first_prompt_type = type(all_requests[0].prompt)
 
     for request in all_requests:
         prompt = request.prompt
@@ -340,6 +345,10 @@ def v1_generate_request(all_requests):
             type(prompt) == first_prompt_type
         ), "All prompts must be of the same type in file input settings"
         prompts.append(prompt)
+        return_logprobs.append(request.logprobs is not None and request.logprobs > 0)
+        top_logprobs_nums.append(
+            request.logprobs if request.logprobs is not None else 0
+        )
         sampling_params_list.append(
             {
                 "temperature": request.temperature,
@@ -361,6 +370,8 @@ def v1_generate_request(all_requests):
     if len(all_requests) == 1:
         prompt = prompts[0]
         sampling_params_list = sampling_params_list[0]
+        return_logprobs = return_logprobs[0]
+        top_logprobs_nums = top_logprobs_nums[0]
         if isinstance(prompt, str) or isinstance(prompt[0], str):
             prompt_kwargs = {"text": prompt}
         else:
@@ -370,15 +381,11 @@ def v1_generate_request(all_requests):
             prompt_kwargs = {"text": prompts}
         else:
             prompt_kwargs = {"input_ids": prompts}
 
     adapted_request = GenerateReqInput(
         **prompt_kwargs,
         sampling_params=sampling_params_list,
-        return_logprob=all_requests[0].logprobs is not None
-        and all_requests[0].logprobs > 0,
-        top_logprobs_num=(
-            all_requests[0].logprobs if all_requests[0].logprobs is not None else 0
-        ),
+        return_logprob=return_logprobs,
+        top_logprobs_num=top_logprobs_nums,
         return_text_in_logprobs=True,
         stream=all_requests[0].stream,
     )
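To make the new arguments concrete: a hypothetical illustration (not code from the commit) of the values v1_generate_request now hands to GenerateReqInput for a batch of two /v1/completions requests versus a single request.

```python
# Hypothetical values mirroring the lists built in the hunks above.
# Batch of two requests, only the first asking for logprobs=5:
return_logprobs = [True, False]   # request.logprobs is not None and request.logprobs > 0
top_logprobs_nums = [5, 0]        # request.logprobs if set, else 0
# -> GenerateReqInput(..., return_logprob=[True, False], top_logprobs_num=[5, 0])

# Single request: the lists are unwrapped to scalars first, so the backend
# receives return_logprob=True and top_logprobs_num=5 instead.
print(return_logprobs[0], top_logprobs_nums[0])
```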
@@ -430,7 +437,7 @@ def v1_generate_response(request, ret, to_file=False):
         logprobs = None
 
         if to_file:
             # to make the choise data json serializable
             choice_data = {
                 "index": 0,
                 "text": text,
@@ -454,7 +461,7 @@ def v1_generate_response(request, ret, to_file=False):
"status_code"
:
200
,
"request_id"
:
ret
[
i
][
"meta_info"
][
"id"
],
"body"
:
{
#
#
remain the same but if needed we can change that
# remain the same but if needed we can change that
"id"
:
ret
[
i
][
"meta_info"
][
"id"
],
"object"
:
"text_completion"
,
"created"
:
int
(
time
.
time
()),
...
...
@@ -590,6 +597,8 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
     texts = []
     sampling_params_list = []
     image_data_list = []
+    return_logprobs = []
+    top_logprobs_nums = []
     for request in all_requests:
         # Prep the data needed for the underlying GenerateReqInput:
         #  - prompt: The full prompt string.
@@ -620,6 +629,8 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
             stop = request.stop
             image_data = None
         texts.append(prompt)
+        return_logprobs.append(request.logprobs)
+        top_logprobs_nums.append(request.top_logprobs)
         sampling_params_list.append(
             {
                 "temperature": request.temperature,
@@ -637,11 +648,16 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
         texts = texts[0]
         sampling_params_list = sampling_params_list[0]
         image_data = image_data_list[0]
+        return_logprobs = return_logprobs[0]
+        top_logprobs_nums = top_logprobs_nums[0]
     adapted_request = GenerateReqInput(
         text=texts,
         image_data=image_data,
         sampling_params=sampling_params_list,
-        stream=request.stream,
+        return_logprob=return_logprobs,
+        top_logprobs_num=top_logprobs_nums,
+        stream=all_requests[0].stream,
+        return_text_in_logprobs=True,
     )
     if len(all_requests) == 1:
         return adapted_request, all_requests[0]
@@ -654,26 +670,63 @@ def v1_chat_generate_response(request, ret, to_file=False):
     total_completion_tokens = 0
 
     for idx, ret_item in enumerate(ret):
+        logprobs = False
+        if isinstance(request, List) and request[idx].logprobs:
+            logprobs = True
+        elif (not isinstance(request, List)) and request.logprobs:
+            logprobs = True
+        if logprobs:
+            logprobs = to_openai_style_logprobs(
+                output_token_logprobs=ret_item["meta_info"]["output_token_logprobs"],
+                output_top_logprobs=ret_item["meta_info"]["output_top_logprobs"],
+            )
+            token_logprobs = []
+            for token, logprob in zip(logprobs.tokens, logprobs.token_logprobs):
+                token_bytes = list(token.encode("utf-8"))
+                top_logprobs = []
+                if logprobs.top_logprobs:
+                    for top_token, top_logprob in logprobs.top_logprobs[0].items():
+                        top_token_bytes = list(top_token.encode("utf-8"))
+                        top_logprobs.append(
+                            TopLogprob(
+                                token=top_token,
+                                bytes=top_token_bytes,
+                                logprob=top_logprob,
+                            )
+                        )
+                token_logprobs.append(
+                    ChatCompletionTokenLogprob(
+                        token=token,
+                        bytes=token_bytes,
+                        logprob=logprob,
+                        top_logprobs=top_logprobs,
+                    )
+                )
+            choice_logprobs = ChoiceLogprobs(content=token_logprobs)
+        else:
+            choice_logprobs = None
         prompt_tokens = ret_item["meta_info"]["prompt_tokens"]
         completion_tokens = ret_item["meta_info"]["completion_tokens"]
         if to_file:
             # to make the choice data json serializable
             choice_data = {
                 "index": 0,
                 "message": {"role": "assistant", "content": ret_item["text"]},
-                "logprobs": None,
+                "logprobs": choice_logprobs,
                 "finish_reason": ret_item["meta_info"]["finish_reason"],
             }
         else:
             choice_data = ChatCompletionResponseChoice(
                 index=idx,
                 message=ChatMessage(role="assistant", content=ret_item["text"]),
+                logprobs=choice_logprobs,
                 finish_reason=ret_item["meta_info"]["finish_reason"],
             )
 
         choices.append(choice_data)
-        total_prompt_tokens = prompt_tokens
+        total_prompt_tokens += prompt_tokens
         total_completion_tokens += completion_tokens
 
     if to_file:
         responses = []
@@ -683,7 +736,7 @@ def v1_chat_generate_response(request, ret, to_file=False):
                 "status_code": 200,
                 "request_id": ret[i]["meta_info"]["id"],
                 "body": {
                     # remain the same but if needed we can change that
                     "id": ret[i]["meta_info"]["id"],
                     "object": "chat.completion",
                     "created": int(time.time()),
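For reference, a self-contained sketch of one entry as built by the conversion loop in v1_chat_generate_response, using the protocol models this commit introduces; the token, log-probability values, and alternatives are made up for illustration.

```python
from sglang.srt.openai_api.protocol import (
    ChatCompletionTokenLogprob,
    ChoiceLogprobs,
    TopLogprob,
)

token = "Paris"
entry = ChatCompletionTokenLogprob(
    token=token,
    bytes=list(token.encode("utf-8")),  # [80, 97, 114, 105, 115]
    logprob=-0.12,                      # illustrative value
    top_logprobs=[
        TopLogprob(token="Paris", bytes=list(b"Paris"), logprob=-0.12),
        TopLogprob(token="Lyon", bytes=list(b"Lyon"), logprob=-2.74),
    ],
)
choice_logprobs = ChoiceLogprobs(content=[entry])
print(choice_logprobs)
```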
python/sglang/srt/openai_api/protocol.py
@@ -54,6 +54,24 @@ class LogProbs(BaseModel):
     top_logprobs: List[Optional[Dict[str, float]]] = Field(default_factory=list)
 
 
+class TopLogprob(BaseModel):
+    token: str
+    bytes: List[int]
+    logprob: float
+
+
+class ChatCompletionTokenLogprob(BaseModel):
+    token: str
+    bytes: List[int]
+    logprob: float
+    top_logprobs: List[TopLogprob]
+
+
+class ChoiceLogprobs(BaseModel):
+    # build for v1/chat/completions response
+    content: List[ChatCompletionTokenLogprob]
+
+
 class UsageInfo(BaseModel):
     prompt_tokens: int = 0
     total_tokens: int = 0
@@ -239,8 +257,8 @@ class ChatMessage(BaseModel):
 class ChatCompletionResponseChoice(BaseModel):
     index: int
     message: ChatMessage
-    logprobs: Optional[LogProbs] = None
-    finish_reason: Optional[str] = None
+    logprobs: Optional[Union[LogProbs, ChoiceLogprobs]] = None
+    finish_reason: str
 
 
 class ChatCompletionResponse(BaseModel):
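A brief sketch of the protocol change in the last hunk: ChatCompletionResponseChoice.logprobs now also accepts the new ChoiceLogprobs structure, and finish_reason becomes a required string. The values below are illustrative only.

```python
from sglang.srt.openai_api.protocol import (
    ChatCompletionResponseChoice,
    ChatCompletionTokenLogprob,
    ChatMessage,
    ChoiceLogprobs,
)

choice = ChatCompletionResponseChoice(
    index=0,
    message=ChatMessage(role="assistant", content="Paris"),
    logprobs=ChoiceLogprobs(
        content=[
            ChatCompletionTokenLogprob(
                token="Paris",
                bytes=[80, 97, 114, 105, 115],
                logprob=-0.12,
                top_logprobs=[],
            )
        ]
    ),
    finish_reason="stop",  # must now be provided (plain str, no default)
)
print(choice)
```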