Commit 70f3e8e3 (unverified)
Authored Feb 25, 2024 by Jared Moore; committed by GitHub on Feb 26, 2024
Parent commit: ef978fe4

Add LogProbs for Chat Completions in OpenAI (#2918)
Showing 3 changed files with 57 additions and 14 deletions (+57 / -14):

  tests/entrypoints/test_openai_server.py   +13  -12
  vllm/entrypoints/openai/protocol.py        +8   -0
  vllm/entrypoints/openai/serving_chat.py   +36   -2
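For orientation (not part of the commit): a minimal sketch of how a client exercises the new parameters against a vLLM OpenAI-compatible server. The base URL, API key, and model name are placeholder assumptions; the request mirrors the updated test below.

# Sketch only: request per-token logprobs for a chat completion, as enabled by
# this commit. Assumes a vLLM OpenAI-compatible server is already running at
# http://localhost:8000/v1 and that MODEL names a model it actually serves.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
MODEL = "meta-llama/Llama-2-7b-chat-hf"  # placeholder model name

chat = client.chat.completions.create(
    model=MODEL,
    messages=[{"role": "user", "content": "What is the capital of France?"}],
    max_tokens=10,
    logprobs=True,     # new ChatCompletionRequest field added by this commit
    top_logprobs=10,   # required whenever logprobs=True (see protocol.py)
)

choice = chat.choices[0]
print(choice.message.content)
# Mirrors the updated test: each generated token gets a mapping of the top-10
# candidate tokens to their log probabilities.
print(choice.logprobs.top_logprobs[0])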
tests/entrypoints/test_openai_server.py

@@ -155,15 +155,18 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI,
     }]

     # test single completion
     chat_completion = await client.chat.completions.create(
         model=model_name,
         messages=messages,
         max_tokens=10,
-    )
+        logprobs=True,
+        top_logprobs=10)
     assert chat_completion.id is not None
     assert chat_completion.choices is not None and len(
         chat_completion.choices) == 1
     assert chat_completion.choices[0].message is not None
+    assert chat_completion.choices[0].logprobs is not None
+    assert chat_completion.choices[0].logprobs.top_logprobs is not None
+    assert len(chat_completion.choices[0].logprobs.top_logprobs[0]) == 10
     message = chat_completion.choices[0].message
     assert message.content is not None and len(message.content) >= 10
     assert message.role == "assistant"
@@ -198,13 +201,11 @@ async def test_completion_streaming(server, client: openai.AsyncOpenAI,
     single_output = single_completion.choices[0].text
     single_usage = single_completion.usage

     stream = await client.completions.create(
         model=model_name,
         prompt=prompt,
         max_tokens=5,
         temperature=0.0,
-        stream=True,
-    )
+        stream=True)
     chunks = []
     async for chunk in stream:
         chunks.append(chunk.choices[0].text)
vllm/entrypoints/openai/protocol.py

@@ -63,6 +63,8 @@ class ChatCompletionRequest(BaseModel):
     seed: Optional[int] = None
     stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
     stream: Optional[bool] = False
+    logprobs: Optional[bool] = False
+    top_logprobs: Optional[int] = None
     presence_penalty: Optional[float] = 0.0
     frequency_penalty: Optional[float] = 0.0
     logit_bias: Optional[Dict[str, float]] = None
@@ -84,6 +86,8 @@ class ChatCompletionRequest(BaseModel):
     length_penalty: Optional[float] = 1.0

     def to_sampling_params(self) -> SamplingParams:
+        if self.logprobs and not self.top_logprobs:
+            raise ValueError("Top logprobs must be set when logprobs is.")
         return SamplingParams(
             n=self.n,
             presence_penalty=self.presence_penalty,
@@ -96,6 +100,8 @@ class ChatCompletionRequest(BaseModel):
             stop=self.stop,
             stop_token_ids=self.stop_token_ids,
             max_tokens=self.max_tokens,
+            logprobs=self.top_logprobs if self.logprobs else None,
+            prompt_logprobs=self.top_logprobs if self.echo else None,
             best_of=self.best_of,
             top_k=self.top_k,
             ignore_eos=self.ignore_eos,
@@ -216,6 +222,7 @@ class ChatMessage(BaseModel):
 class ChatCompletionResponseChoice(BaseModel):
     index: int
     message: ChatMessage
+    logprobs: Optional[LogProbs] = None
     finish_reason: Optional[Literal["stop", "length"]] = None
@@ -236,6 +243,7 @@ class DeltaMessage(BaseModel):
 class ChatCompletionResponseStreamChoice(BaseModel):
     index: int
     delta: DeltaMessage
+    logprobs: Optional[LogProbs] = None
     finish_reason: Optional[Literal["stop", "length"]] = None
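A short sketch (not from the diff) of the behavior the new ChatCompletionRequest fields introduce; the model name and message content are placeholders.

# Sketch only: the new logprobs/top_logprobs fields on ChatCompletionRequest.
from vllm.entrypoints.openai.protocol import ChatCompletionRequest

ok = ChatCompletionRequest(
    model="my-model",  # placeholder model name
    messages=[{"role": "user", "content": "hi"}],
    logprobs=True,
    top_logprobs=5,
)
# top_logprobs is forwarded to SamplingParams.logprobs when logprobs=True.
print(ok.to_sampling_params().logprobs)  # 5

bad = ChatCompletionRequest(
    model="my-model",
    messages=[{"role": "user", "content": "hi"}],
    logprobs=True,  # logprobs requested without top_logprobs...
)
# ...so to_sampling_params() raises
# ValueError("Top logprobs must be set when logprobs is.")
bad.to_sampling_params()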
vllm/entrypoints/openai/serving_chat.py

@@ -101,7 +101,10 @@ class OpenAIServingChat(OpenAIServing):
         role = self.get_chat_request_role(request)
         for i in range(request.n):
             choice_data = ChatCompletionResponseStreamChoice(
-                index=i, delta=DeltaMessage(role=role), finish_reason=None)
+                index=i,
+                delta=DeltaMessage(role=role),
+                logprobs=None,
+                finish_reason=None)
             chunk = ChatCompletionStreamResponse(id=request_id,
                                                  object=chunk_object_type,
                                                  created=created_time,
@@ -118,6 +121,7 @@ class OpenAIServingChat(OpenAIServing):
                     "content") and request.messages[-1].get(
                         "role") == role:
                 last_msg_content = request.messages[-1]["content"]
             if last_msg_content:
                 for i in range(request.n):
                     choice_data = ChatCompletionResponseStreamChoice(
...
@@ -129,6 +133,7 @@ class OpenAIServingChat(OpenAIServing):
                         object=chunk_object_type,
                         created=created_time,
                         choices=[choice_data],
+                        logprobs=None,
                         model=model_name)
                     data = chunk.model_dump_json(exclude_unset=True)
                     yield f"data: {data}\n\n"
@@ -145,15 +150,29 @@ class OpenAIServingChat(OpenAIServing):
                 if finish_reason_sent[i]:
                     continue

+                delta_token_ids = output.token_ids[previous_num_tokens[i]:]
+                top_logprobs = output.logprobs[
+                    previous_num_tokens[i]:] if output.logprobs else None
+
+                if request.logprobs:
+                    logprobs = self._create_logprobs(
+                        token_ids=delta_token_ids,
+                        top_logprobs=top_logprobs,
+                        num_output_top_logprobs=request.logprobs,
+                        initial_text_offset=len(previous_texts[i]),
+                    )
+                else:
+                    logprobs = None
+
                 delta_text = output.text[len(previous_texts[i]):]
                 previous_texts[i] = output.text
                 previous_num_tokens[i] = len(output.token_ids)
                 if output.finish_reason is None:
                     # Send token-by-token response for each request.n
                     choice_data = ChatCompletionResponseStreamChoice(
                         index=i,
                         delta=DeltaMessage(content=delta_text),
+                        logprobs=logprobs,
                         finish_reason=None)
                     chunk = ChatCompletionStreamResponse(
                         id=request_id,
@@ -174,6 +193,7 @@ class OpenAIServingChat(OpenAIServing):
                     choice_data = ChatCompletionResponseStreamChoice(
                         index=i,
                         delta=DeltaMessage(content=delta_text),
+                        logprobs=logprobs,
                         finish_reason=output.finish_reason)
                     chunk = ChatCompletionStreamResponse(
                         id=request_id,
@@ -208,11 +228,25 @@ class OpenAIServingChat(OpenAIServing):
         assert final_res is not None

         choices = []
         role = self.get_chat_request_role(request)
         for output in final_res.outputs:
+            token_ids = output.token_ids
+            top_logprobs = output.logprobs
+
+            if request.logprobs:
+                logprobs = self._create_logprobs(
+                    token_ids=token_ids,
+                    top_logprobs=top_logprobs,
+                    num_output_top_logprobs=request.logprobs,
+                )
+            else:
+                logprobs = None
+
             choice_data = ChatCompletionResponseChoice(
                 index=output.index,
                 message=ChatMessage(role=role, content=output.text),
+                logprobs=logprobs,
                 finish_reason=output.finish_reason,
             )
             choices.append(choice_data)
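Finally, a hedged sketch of how the per-choice logprobs added above surface in streaming mode (not part of the commit; same placeholder server URL and model name as the first example, with the server running this change).

# Sketch only: reading logprobs from streamed chat chunks served by vLLM with
# this change applied. Server URL and MODEL are placeholders, as above.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
MODEL = "meta-llama/Llama-2-7b-chat-hf"  # placeholder model name

stream = client.chat.completions.create(
    model=MODEL,
    messages=[{"role": "user", "content": "Say hello"}],
    max_tokens=5,
    logprobs=True,
    top_logprobs=5,
    stream=True,
)

for chunk in stream:
    choice = chunk.choices[0]
    # Each ChatCompletionResponseStreamChoice now carries a logprobs field; it
    # is None for the initial role-only chunk and populated for content chunks.
    if choice.delta.content:
        print(repr(choice.delta.content), choice.logprobs)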