Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f878c8fe
Unverified
Commit
f878c8fe
authored
Aug 16, 2024
by
Grant Pinkert
Committed by
GitHub
Aug 16, 2024
Browse files
[Feature]: Add OpenAI server prompt_logprobs support #6508 (#7453)
parent
b67ae00c
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
154 additions
and
3 deletions
+154
-3
tests/entrypoints/openai/test_completion.py
tests/entrypoints/openai/test_completion.py
+124
-1
vllm/entrypoints/openai/protocol.py
vllm/entrypoints/openai/protocol.py
+9
-2
vllm/entrypoints/openai/serving_chat.py
vllm/entrypoints/openai/serving_chat.py
+11
-0
vllm/entrypoints/openai/serving_completion.py
vllm/entrypoints/openai/serving_completion.py
+10
-0
No files found.
tests/entrypoints/openai/test_completion.py
View file @
f878c8fe
...
@@ -3,7 +3,7 @@ import json
...
@@ -3,7 +3,7 @@ import json
import
re
import
re
import
shutil
import
shutil
from
tempfile
import
TemporaryDirectory
from
tempfile
import
TemporaryDirectory
from
typing
import
List
from
typing
import
Dict
,
List
import
jsonschema
import
jsonschema
import
openai
# use the official client for correctness check
import
openai
# use the official client for correctness check
...
@@ -130,6 +130,7 @@ async def test_single_completion(client: openai.AsyncOpenAI, model_name: str,
...
@@ -130,6 +130,7 @@ async def test_single_completion(client: openai.AsyncOpenAI, model_name: str,
temperature
=
0.0
,
temperature
=
0.0
,
)
)
assert
len
(
completion
.
choices
[
0
].
text
)
>=
1
assert
len
(
completion
.
choices
[
0
].
text
)
>=
1
assert
completion
.
choices
[
0
].
prompt_logprobs
is
None
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
asyncio
...
@@ -267,6 +268,128 @@ async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI,
...
@@ -267,6 +268,128 @@ async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI,
assert
len
(
completion
.
choices
[
0
].
text
)
>=
0
assert
len
(
completion
.
choices
[
0
].
text
)
>=
0
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name, prompt_logprobs",
    [(MODEL_NAME, 1), (MODEL_NAME, 0), (MODEL_NAME, -1), (MODEL_NAME, None)],
)
async def test_prompt_logprobs_chat(client: openai.AsyncOpenAI,
                                    model_name: str, prompt_logprobs: int):
    """Exercise the ``prompt_logprobs`` extra parameter on chat completions.

    Parametrized cases:
      * negative value -> the request must fail with ``BadRequestError``
        carrying the exact error string checked below;
      * positive value -> the response must carry a non-empty
        ``prompt_logprobs``;
      * ``0`` or ``None`` -> the response's ``prompt_logprobs`` must be
        ``None`` (for ``None`` the parameter is not sent at all).
    """
    conversation = [
        {
            "role": "system",
            "content": "You are a helpful assistant."
        },
        {
            "role": "user",
            "content": "Who won the world series in 2020?"
        },
        {
            "role": "assistant",
            "content": "The Los Angeles Dodgers won the World Series in 2020."
        },
        {
            "role": "user",
            "content": "Where was it played?"
        },
    ]
    request_kwargs: Dict = {"messages": conversation, "model": model_name}

    # prompt_logprobs is not a field of the official client's signature, so
    # it is passed through ``extra_body`` -- and only when a value is given.
    if prompt_logprobs is not None:
        request_kwargs["extra_body"] = {"prompt_logprobs": prompt_logprobs}

    # Invalid (negative) value: expect a 400 and stop here.
    if prompt_logprobs and prompt_logprobs < 0:
        with pytest.raises(BadRequestError) as exc_info:
            await client.chat.completions.create(**request_kwargs)
        expected_message = (
            "Error code: 400 - {'object': 'error', 'message': "
            "'Prompt_logprobs set to invalid negative value: -1',"
            " 'type': 'BadRequestError', 'param': None, 'code': 400}")
        assert str(exc_info.value) == expected_message
        return

    response = await client.chat.completions.create(**request_kwargs)

    if prompt_logprobs and prompt_logprobs > 0:
        assert response.prompt_logprobs is not None
        assert len(response.prompt_logprobs) > 0
    else:
        # Covers both prompt_logprobs == 0 and prompt_logprobs is None.
        assert response.prompt_logprobs is None
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    [MODEL_NAME],
)
async def test_more_than_one_prompt_logprobs_chat(client: openai.AsyncOpenAI,
                                                  model_name: str):
    """Check that the requested ``prompt_logprobs`` count is reflected.

    The same conversation is sent twice, first with ``prompt_logprobs=1``
    and then with ``prompt_logprobs=2``; the entry at index 3 of the
    returned ``prompt_logprobs`` must hold exactly 1 and 2 items
    respectively.
    """
    conversation = [
        {
            "role": "system",
            "content": "You are a helpful assistant."
        },
        {
            "role": "user",
            "content": "Who won the world series in 2020?"
        },
        {
            "role": "assistant",
            "content": "The Los Angeles Dodgers won the World Series in 2020."
        },
        {
            "role": "user",
            "content": "Where was it played?"
        },
    ]
    base_kwargs: Dict = {"messages": conversation, "model": model_name}

    # Issue both requests before asserting anything, one after the other.
    single = await client.chat.completions.create(
        **base_kwargs, extra_body={"prompt_logprobs": 1})
    double = await client.chat.completions.create(
        **base_kwargs, extra_body={"prompt_logprobs": 2})

    assert len(single.prompt_logprobs[3]) == 1
    assert len(double.prompt_logprobs[3]) == 2
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name, prompt_logprobs", [(MODEL_NAME, -1),
                                                         (MODEL_NAME, 0),
                                                         (MODEL_NAME, 1),
                                                         (MODEL_NAME, None)])
async def test_prompt_logprobs_completion(client: openai.AsyncOpenAI,
                                          model_name: str,
                                          prompt_logprobs: int):
    """Exercise the ``prompt_logprobs`` extra parameter on completions.

    Two prompts are submitted in one request, so two choices are inspected.

    Parametrized cases:
      * negative value -> the request must fail with ``BadRequestError``
        carrying the exact error string checked below;
      * positive value -> both choices must carry a non-empty
        ``prompt_logprobs``;
      * ``0`` or ``None`` -> both choices must have ``prompt_logprobs`` set
        to ``None`` (for ``None`` the parameter is not sent at all).
    """
    params: Dict = {
        "prompt": ["A robot may not injure another robot", "My name is"],
        "model": model_name,
    }
    # prompt_logprobs is not part of the official client's signature, so it
    # travels in ``extra_body`` and is omitted entirely when None.
    if prompt_logprobs is not None:
        params["extra_body"] = {"prompt_logprobs": prompt_logprobs}

    if prompt_logprobs and prompt_logprobs < 0:
        with pytest.raises(BadRequestError) as err_info:
            await client.completions.create(**params)
        expected_err_string = (
            "Error code: 400 - {'object': 'error', 'message': "
            "'Prompt_logprobs set to invalid negative value: -1',"
            " 'type': 'BadRequestError', 'param': None, 'code': 400}")
        assert str(err_info.value) == expected_err_string
    else:
        completion = await client.completions.create(**params)
        if prompt_logprobs and prompt_logprobs > 0:
            assert completion.choices[0].prompt_logprobs is not None
            assert len(completion.choices[0].prompt_logprobs) > 0

            assert completion.choices[1].prompt_logprobs is not None
            assert len(completion.choices[1].prompt_logprobs) > 0
        else:
            # Covers both prompt_logprobs == 0 and prompt_logprobs is None.
            # Check the second prompt's choice too, mirroring the positive
            # branch, so a leak on only one choice cannot go unnoticed.
            assert completion.choices[0].prompt_logprobs is None
            assert completion.choices[1].prompt_logprobs is None
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"model_name"
,
"model_name"
,
...
...
vllm/entrypoints/openai/protocol.py
View file @
f878c8fe
...
@@ -13,6 +13,7 @@ from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
...
@@ -13,6 +13,7 @@ from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
from
vllm.entrypoints.openai.logits_processors
import
get_logits_processors
from
vllm.entrypoints.openai.logits_processors
import
get_logits_processors
from
vllm.pooling_params
import
PoolingParams
from
vllm.pooling_params
import
PoolingParams
from
vllm.sampling_params
import
LogitsProcessor
,
SamplingParams
from
vllm.sampling_params
import
LogitsProcessor
,
SamplingParams
from
vllm.sequence
import
Logprob
from
vllm.utils
import
random_uuid
from
vllm.utils
import
random_uuid
# torch is mocked during docs generation,
# torch is mocked during docs generation,
...
@@ -152,6 +153,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
...
@@ -152,6 +153,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
skip_special_tokens
:
bool
=
True
skip_special_tokens
:
bool
=
True
spaces_between_special_tokens
:
bool
=
True
spaces_between_special_tokens
:
bool
=
True
truncate_prompt_tokens
:
Optional
[
Annotated
[
int
,
Field
(
ge
=
1
)]]
=
None
truncate_prompt_tokens
:
Optional
[
Annotated
[
int
,
Field
(
ge
=
1
)]]
=
None
prompt_logprobs
:
Optional
[
int
]
=
None
# doc: end-chat-completion-sampling-params
# doc: end-chat-completion-sampling-params
# doc: begin-chat-completion-extra-params
# doc: begin-chat-completion-extra-params
...
@@ -263,7 +265,8 @@ class ChatCompletionRequest(OpenAIBaseModel):
...
@@ -263,7 +265,8 @@ class ChatCompletionRequest(OpenAIBaseModel):
stop
=
self
.
stop
,
stop
=
self
.
stop
,
stop_token_ids
=
self
.
stop_token_ids
,
stop_token_ids
=
self
.
stop_token_ids
,
logprobs
=
self
.
top_logprobs
if
self
.
logprobs
else
None
,
logprobs
=
self
.
top_logprobs
if
self
.
logprobs
else
None
,
prompt_logprobs
=
self
.
top_logprobs
if
self
.
echo
else
None
,
prompt_logprobs
=
self
.
prompt_logprobs
if
self
.
prompt_logprobs
else
(
self
.
top_logprobs
if
self
.
echo
else
None
),
ignore_eos
=
self
.
ignore_eos
,
ignore_eos
=
self
.
ignore_eos
,
max_tokens
=
max_tokens
,
max_tokens
=
max_tokens
,
min_tokens
=
self
.
min_tokens
,
min_tokens
=
self
.
min_tokens
,
...
@@ -368,6 +371,7 @@ class CompletionRequest(OpenAIBaseModel):
...
@@ -368,6 +371,7 @@ class CompletionRequest(OpenAIBaseModel):
spaces_between_special_tokens
:
bool
=
True
spaces_between_special_tokens
:
bool
=
True
truncate_prompt_tokens
:
Optional
[
Annotated
[
int
,
Field
(
ge
=
1
)]]
=
None
truncate_prompt_tokens
:
Optional
[
Annotated
[
int
,
Field
(
ge
=
1
)]]
=
None
allowed_token_ids
:
Optional
[
List
[
int
]]
=
None
allowed_token_ids
:
Optional
[
List
[
int
]]
=
None
prompt_logprobs
:
Optional
[
int
]
=
None
# doc: end-completion-sampling-params
# doc: end-completion-sampling-params
# doc: begin-completion-extra-params
# doc: begin-completion-extra-params
...
@@ -454,7 +458,8 @@ class CompletionRequest(OpenAIBaseModel):
...
@@ -454,7 +458,8 @@ class CompletionRequest(OpenAIBaseModel):
min_tokens
=
self
.
min_tokens
,
min_tokens
=
self
.
min_tokens
,
use_beam_search
=
self
.
use_beam_search
,
use_beam_search
=
self
.
use_beam_search
,
early_stopping
=
self
.
early_stopping
,
early_stopping
=
self
.
early_stopping
,
prompt_logprobs
=
self
.
logprobs
if
self
.
echo
else
None
,
prompt_logprobs
=
self
.
prompt_logprobs
if
self
.
prompt_logprobs
else
self
.
logprobs
if
self
.
echo
else
None
,
skip_special_tokens
=
self
.
skip_special_tokens
,
skip_special_tokens
=
self
.
skip_special_tokens
,
spaces_between_special_tokens
=
self
.
spaces_between_special_tokens
,
spaces_between_special_tokens
=
self
.
spaces_between_special_tokens
,
include_stop_str_in_output
=
self
.
include_stop_str_in_output
,
include_stop_str_in_output
=
self
.
include_stop_str_in_output
,
...
@@ -532,6 +537,7 @@ class CompletionResponseChoice(OpenAIBaseModel):
...
@@ -532,6 +537,7 @@ class CompletionResponseChoice(OpenAIBaseModel):
"to stop, None if the completion finished for some other reason "
"to stop, None if the completion finished for some other reason "
"including encountering the EOS token"
),
"including encountering the EOS token"
),
)
)
prompt_logprobs
:
Optional
[
List
[
Optional
[
Dict
[
int
,
Logprob
]]]]
=
None
class
CompletionResponse
(
OpenAIBaseModel
):
class
CompletionResponse
(
OpenAIBaseModel
):
...
@@ -627,6 +633,7 @@ class ChatCompletionResponse(OpenAIBaseModel):
...
@@ -627,6 +633,7 @@ class ChatCompletionResponse(OpenAIBaseModel):
model
:
str
model
:
str
choices
:
List
[
ChatCompletionResponseChoice
]
choices
:
List
[
ChatCompletionResponseChoice
]
usage
:
UsageInfo
usage
:
UsageInfo
prompt_logprobs
:
Optional
[
List
[
Optional
[
Dict
[
int
,
Logprob
]]]]
=
None
class
DeltaMessage
(
OpenAIBaseModel
):
class
DeltaMessage
(
OpenAIBaseModel
):
...
...
vllm/entrypoints/openai/serving_chat.py
View file @
f878c8fe
...
@@ -83,6 +83,16 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -83,6 +83,16 @@ class OpenAIServingChat(OpenAIServing):
if
error_check_ret
is
not
None
:
if
error_check_ret
is
not
None
:
return
error_check_ret
return
error_check_ret
if
request
.
prompt_logprobs
is
not
None
:
if
request
.
stream
and
request
.
prompt_logprobs
>
0
:
return
self
.
create_error_response
(
"Prompt_logprobs are not available when stream is enabled"
)
if
request
.
prompt_logprobs
<
0
:
return
self
.
create_error_response
(
f
"Prompt_logprobs set to invalid "
f
"negative value:
{
request
.
prompt_logprobs
}
"
)
try
:
try
:
(
(
lora_request
,
lora_request
,
...
@@ -506,6 +516,7 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -506,6 +516,7 @@ class OpenAIServingChat(OpenAIServing):
model
=
model_name
,
model
=
model_name
,
choices
=
choices
,
choices
=
choices
,
usage
=
usage
,
usage
=
usage
,
prompt_logprobs
=
final_res
.
prompt_logprobs
,
)
)
return
response
return
response
...
...
vllm/entrypoints/openai/serving_completion.py
View file @
f878c8fe
...
@@ -84,6 +84,15 @@ class OpenAIServingCompletion(OpenAIServing):
...
@@ -84,6 +84,15 @@ class OpenAIServingCompletion(OpenAIServing):
request_id
=
f
"cmpl-
{
random_uuid
()
}
"
request_id
=
f
"cmpl-
{
random_uuid
()
}
"
created_time
=
int
(
time
.
time
())
created_time
=
int
(
time
.
time
())
if
request
.
prompt_logprobs
is
not
None
:
if
request
.
stream
and
request
.
prompt_logprobs
>
0
:
return
self
.
create_error_response
(
"Prompt_logprobs are not available when stream is enabled"
)
elif
request
.
prompt_logprobs
<
0
:
return
self
.
create_error_response
(
f
"Prompt_logprobs set to invalid negative "
f
"value:
{
request
.
prompt_logprobs
}
"
)
# Schedule the request and get the result generator.
# Schedule the request and get the result generator.
generators
:
List
[
AsyncGenerator
[
RequestOutput
,
None
]]
=
[]
generators
:
List
[
AsyncGenerator
[
RequestOutput
,
None
]]
=
[]
try
:
try
:
...
@@ -377,6 +386,7 @@ class OpenAIServingCompletion(OpenAIServing):
...
@@ -377,6 +386,7 @@ class OpenAIServingCompletion(OpenAIServing):
logprobs
=
logprobs
,
logprobs
=
logprobs
,
finish_reason
=
output
.
finish_reason
,
finish_reason
=
output
.
finish_reason
,
stop_reason
=
output
.
stop_reason
,
stop_reason
=
output
.
stop_reason
,
prompt_logprobs
=
final_res
.
prompt_logprobs
,
)
)
choices
.
append
(
choice_data
)
choices
.
append
(
choice_data
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment