norm / vllm

Commit 70f3e8e3 (unverified), authored Feb 25, 2024 by Jared Moore; committed via GitHub on Feb 26, 2024.
Add LogProbs for Chat Completions in OpenAI (#2918)
Parent: ef978fe4

Showing 3 changed files with 57 additions and 14 deletions:

  tests/entrypoints/test_openai_server.py   +13  -12
  vllm/entrypoints/openai/protocol.py        +8   -0
  vllm/entrypoints/openai/serving_chat.py   +36   -2
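The new parameters mirror OpenAI's chat logprobs API, and the updated test below exercises them end to end. As a minimal usage sketch, assuming a vLLM OpenAI-compatible server on localhost:8000 and an illustrative model name (both are placeholders, not part of this commit):

import asyncio

import openai


async def main():
    # Placeholder endpoint and model; any vLLM OpenAI-compatible server works.
    client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1",
                                api_key="EMPTY")
    chat_completion = await client.chat.completions.create(
        model="example-model",
        messages=[{"role": "user", "content": "What is an LLM?"}],
        max_tokens=10,
        logprobs=True,      # new in this commit for chat completions
        top_logprobs=10)    # number of alternatives returned per token
    choice = chat_completion.choices[0]
    # As asserted in the test below, each position carries 10 alternatives.
    print(choice.logprobs.top_logprobs[0])


asyncio.run(main())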
tests/entrypoints/test_openai_server.py
@@ -155,15 +155,18 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI,
     }]

     # test single completion
     chat_completion = await client.chat.completions.create(model=model_name,
                                                            messages=messages,
                                                            max_tokens=10,
-                                                           )
+                                                           logprobs=True,
+                                                           top_logprobs=10)
     assert chat_completion.id is not None
     assert chat_completion.choices is not None and len(
         chat_completion.choices) == 1
     assert chat_completion.choices[0].message is not None
+    assert chat_completion.choices[0].logprobs is not None
+    assert chat_completion.choices[0].logprobs.top_logprobs is not None
+    assert len(chat_completion.choices[0].logprobs.top_logprobs[0]) == 10
     message = chat_completion.choices[0].message
     assert message.content is not None and len(message.content) >= 10
     assert message.role == "assistant"
@@ -198,13 +201,11 @@ async def test_completion_streaming(server, client: openai.AsyncOpenAI,
     single_output = single_completion.choices[0].text
     single_usage = single_completion.usage

     stream = await client.completions.create(model=model_name,
                                              prompt=prompt,
                                              max_tokens=5,
                                              temperature=0.0,
-                                             stream=True)
+                                             stream=True,
+                                             )
     chunks = []
     async for chunk in stream:
         chunks.append(chunk.choices[0].text)
vllm/entrypoints/openai/protocol.py
@@ -63,6 +63,8 @@ class ChatCompletionRequest(BaseModel):
     seed: Optional[int] = None
     stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
     stream: Optional[bool] = False
+    logprobs: Optional[bool] = False
+    top_logprobs: Optional[int] = None
     presence_penalty: Optional[float] = 0.0
     frequency_penalty: Optional[float] = 0.0
     logit_bias: Optional[Dict[str, float]] = None
@@ -84,6 +86,8 @@ class ChatCompletionRequest(BaseModel):
     length_penalty: Optional[float] = 1.0

     def to_sampling_params(self) -> SamplingParams:
+        if self.logprobs and not self.top_logprobs:
+            raise ValueError("Top logprobs must be set when logprobs is.")
         return SamplingParams(
             n=self.n,
             presence_penalty=self.presence_penalty,
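A consequence of the guard above: logprobs=True without top_logprobs is rejected before sampling starts (and, because the check uses truthiness, top_logprobs=0 is rejected as well). A standalone sketch of the same check, with a hypothetical helper name:

from typing import Optional


def check_logprobs(logprobs: Optional[bool],
                   top_logprobs: Optional[int]) -> None:
    # Mirrors the guard in ChatCompletionRequest.to_sampling_params().
    if logprobs and not top_logprobs:
        raise ValueError("Top logprobs must be set when logprobs is.")


check_logprobs(True, 10)      # accepted
check_logprobs(False, None)   # accepted: logprobs disabled
try:
    check_logprobs(True, None)  # rejected
except ValueError as err:
    print(err)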
@@ -96,6 +100,8 @@ class ChatCompletionRequest(BaseModel):
             stop=self.stop,
             stop_token_ids=self.stop_token_ids,
             max_tokens=self.max_tokens,
+            logprobs=self.top_logprobs if self.logprobs else None,
+            prompt_logprobs=self.top_logprobs if self.echo else None,
             best_of=self.best_of,
             top_k=self.top_k,
             ignore_eos=self.ignore_eos,
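This mapping collapses the user-facing pair (logprobs, top_logprobs) into engine-side counts: SamplingParams.logprobs is set only when the boolean flag is on, while prompt_logprobs is keyed off echo instead. A small sketch of the translation, with hypothetical names:

from typing import Optional, Tuple


def to_engine_counts(
        logprobs: Optional[bool], top_logprobs: Optional[int],
        echo: Optional[bool]) -> Tuple[Optional[int], Optional[int]]:
    # Returns (SamplingParams.logprobs, SamplingParams.prompt_logprobs).
    return (top_logprobs if logprobs else None,
            top_logprobs if echo else None)


assert to_engine_counts(True, 10, False) == (10, None)
assert to_engine_counts(False, 10, False) == (None, None)
assert to_engine_counts(True, 10, True) == (10, 10)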
@@ -216,6 +222,7 @@ class ChatMessage(BaseModel):

 class ChatCompletionResponseChoice(BaseModel):
     index: int
     message: ChatMessage
+    logprobs: Optional[LogProbs] = None
     finish_reason: Optional[Literal["stop", "length"]] = None
@@ -236,6 +243,7 @@ class DeltaMessage(BaseModel):

 class ChatCompletionResponseStreamChoice(BaseModel):
     index: int
     delta: DeltaMessage
+    logprobs: Optional[LogProbs] = None
     finish_reason: Optional[Literal["stop", "length"]] = None
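Note that the logprobs field on both choice models reuses the existing LogProbs model from the completions endpoint rather than OpenAI's newer chat-style schema. A hypothetical response fragment, assuming the completions-era shape (tokens, token_logprobs, text_offset, top_logprobs); all values are made up:

# Hypothetical serialized choice; values are illustrative only.
choice = {
    "index": 0,
    "message": {"role": "assistant", "content": "An LLM"},
    "logprobs": {
        "text_offset": [0, 2],
        "token_logprobs": [-0.11, -0.45],
        "tokens": ["An", " LLM"],
        "top_logprobs": [
            {"An": -0.11, "A": -2.31},      # alternatives at position 0
            {" LLM": -0.45, " AI": -1.72},  # alternatives at position 1
        ],
    },
    "finish_reason": "length",
}
# One alternatives dict per generated token.
assert len(choice["logprobs"]["top_logprobs"]) == len(
    choice["logprobs"]["tokens"])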
vllm/entrypoints/openai/serving_chat.py
@@ -101,7 +101,10 @@ class OpenAIServingChat(OpenAIServing):
         role = self.get_chat_request_role(request)
         for i in range(request.n):
             choice_data = ChatCompletionResponseStreamChoice(
-                index=i, delta=DeltaMessage(role=role), finish_reason=None)
+                index=i,
+                delta=DeltaMessage(role=role),
+                logprobs=None,
+                finish_reason=None)
             chunk = ChatCompletionStreamResponse(id=request_id,
                                                  object=chunk_object_type,
                                                  created=created_time,
@@ -118,6 +121,7 @@ class OpenAIServingChat(OpenAIServing):
                         "content") and request.messages[-1].get(
                             "role") == role:
                 last_msg_content = request.messages[-1]["content"]
+
             if last_msg_content:
                 for i in range(request.n):
                     choice_data = ChatCompletionResponseStreamChoice(
@@ -129,6 +133,7 @@ class OpenAIServingChat(OpenAIServing):
                         object=chunk_object_type,
                         created=created_time,
                         choices=[choice_data],
+                        logprobs=None,
                         model=model_name)
                     data = chunk.model_dump_json(exclude_unset=True)
                     yield f"data: {data}\n\n"
@@ -145,15 +150,29 @@ class OpenAIServingChat(OpenAIServing):
                 if finish_reason_sent[i]:
                     continue

+                delta_token_ids = output.token_ids[previous_num_tokens[i]:]
+                top_logprobs = output.logprobs[
+                    previous_num_tokens[i]:] if output.logprobs else None
+                if request.logprobs:
+                    logprobs = self._create_logprobs(
+                        token_ids=delta_token_ids,
+                        top_logprobs=top_logprobs,
+                        num_output_top_logprobs=request.logprobs,
+                        initial_text_offset=len(previous_texts[i]),
+                    )
+                else:
+                    logprobs = None
+
                 delta_text = output.text[len(previous_texts[i]):]
                 previous_texts[i] = output.text
                 previous_num_tokens[i] = len(output.token_ids)
                 if output.finish_reason is None:
                     # Send token-by-token response for each request.n
                     choice_data = ChatCompletionResponseStreamChoice(
                         index=i,
                         delta=DeltaMessage(content=delta_text),
+                        logprobs=logprobs,
                         finish_reason=None)
                     chunk = ChatCompletionStreamResponse(
                         id=request_id,
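The streaming path hands _create_logprobs only the not-yet-sent tail of the output: tokens are sliced at previous_num_tokens[i], and initial_text_offset keeps text offsets aligned with the full completion string. A toy illustration of that bookkeeping (made-up data, not vLLM code):

# Tokens generated so far vs. tokens already streamed to the client.
output_token_ids = [11, 22, 33, 44, 55]
previous_num_tokens = 3

# Only the new tail goes into the next chunk's logprobs.
delta_token_ids = output_token_ids[previous_num_tokens:]
assert delta_token_ids == [44, 55]

previous_text = "Hello wor"
output_text = "Hello world!"
delta_text = output_text[len(previous_text):]
assert delta_text == "ld!"
# initial_text_offset=len(previous_text) makes the per-token offsets in the
# chunk's logprobs point into the full completion, not just the delta.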
@@ -174,6 +193,7 @@ class OpenAIServingChat(OpenAIServing):
                     choice_data = ChatCompletionResponseStreamChoice(
                         index=i,
                         delta=DeltaMessage(content=delta_text),
+                        logprobs=logprobs,
                         finish_reason=output.finish_reason)
                     chunk = ChatCompletionStreamResponse(
                         id=request_id,
@@ -208,11 +228,25 @@ class OpenAIServingChat(OpenAIServing):
         assert final_res is not None

         choices = []
+
         role = self.get_chat_request_role(request)
         for output in final_res.outputs:
+            token_ids = output.token_ids
+            top_logprobs = output.logprobs
+
+            if request.logprobs:
+                logprobs = self._create_logprobs(
+                    token_ids=token_ids,
+                    top_logprobs=top_logprobs,
+                    num_output_top_logprobs=request.logprobs,
+                )
+            else:
+                logprobs = None
+
             choice_data = ChatCompletionResponseChoice(
                 index=output.index,
                 message=ChatMessage(role=role, content=output.text),
+                logprobs=logprobs,
                 finish_reason=output.finish_reason,
             )
             choices.append(choice_data)
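On the non-streaming path all output tokens are converted at once, so each response choice carries one complete logprobs object; when streaming, the role-announcement and echo chunks carry logprobs=None and each content chunk carries logprobs for just its delta. A consumption sketch against the same placeholder server as above:

import asyncio

import openai


async def stream_with_logprobs():
    client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1",
                                api_key="EMPTY")
    stream = await client.chat.completions.create(
        model="example-model",
        messages=[{"role": "user", "content": "What is an LLM?"}],
        max_tokens=10,
        logprobs=True,
        top_logprobs=10,
        stream=True)
    async for chunk in stream:
        choice = chunk.choices[0]
        # Role-only chunks arrive with logprobs=None; content chunks carry
        # the logprobs for their newly generated tokens.
        if choice.logprobs is not None:
            print(choice.delta.content, choice.logprobs.top_logprobs)


asyncio.run(stream_with_logprobs())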