Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
95e7d4a9
Unverified
Commit
95e7d4a9
authored
Apr 11, 2024
by
Dylan Hawk
Committed by
GitHub
Apr 11, 2024
Browse files
Fix echo/logprob OpenAI completion bug (#3441)
Co-authored-by:
Dylan Hawk
<
dylanwawk@gmail.com
>
parent
559eb852
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
73 additions
and
29 deletions
+73
-29
tests/entrypoints/test_openai_server.py
tests/entrypoints/test_openai_server.py
+31
-0
vllm/entrypoints/openai/serving_chat.py
vllm/entrypoints/openai/serving_chat.py
+5
-4
vllm/entrypoints/openai/serving_completion.py
vllm/entrypoints/openai/serving_completion.py
+10
-5
vllm/entrypoints/openai/serving_engine.py
vllm/entrypoints/openai/serving_engine.py
+27
-20
No files found.
tests/entrypoints/test_openai_server.py
View file @
95e7d4a9
...
@@ -742,5 +742,36 @@ number: "1" | "2"
...
@@ -742,5 +742,36 @@ number: "1" | "2"
assert
content
.
strip
()
==
ground_truth
assert
content
.
strip
()
==
ground_truth
@pytest.mark.parametrize(
    # first test base model, then test loras
    "model_name",
    [MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
)
async def test_echo_logprob_completion(server, client: openai.AsyncOpenAI,
                                       model_name: str):
    """Check a completion request that combines ``echo=True`` with logprobs.

    For both a text prompt and a token-ID prompt, the test requests five new
    tokens with ``echo=True`` and ``logprobs=1`` and verifies that:

    * the returned text begins with the prompt text, and
    * every logprobs field has more than five entries (i.e. it covers prompt
      tokens in addition to the five generated ones), with ``None`` in the
      first slot of ``token_logprobs`` and ``top_logprobs``.

    Args:
        server: Fixture supplying the running API server (not used directly).
        client: Async OpenAI client used to issue the request.
        model_name: Name of the base model or LoRA adapter under test.
    """
    tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)
    # test using text and token IDs
    for prompt in ("Hello, my name is", [0, 0, 0, 0, 0]):
        completion = await client.completions.create(model=model_name,
                                                     prompt=prompt,
                                                     max_tokens=5,
                                                     temperature=0.0,
                                                     echo=True,
                                                     logprobs=1)

        # A token-ID prompt has to be decoded before it can be compared with
        # the echoed text.
        prompt_text = (tokenizer.decode(prompt)
                       if isinstance(prompt, list) else prompt)

        echoed_text = completion.choices[0].text
        assert echoed_text is not None
        # Literal prefix comparison: building a regex from the prompt would
        # misinterpret any metacharacters the prompt text happens to contain.
        assert echoed_text.startswith(prompt_text)

        logprobs = completion.choices[0].logprobs
        assert logprobs is not None
        assert len(logprobs.text_offset) > 5
        # Separate asserts so a failure pinpoints which condition broke.
        assert len(logprobs.token_logprobs) > 5
        # The first echoed token is expected to carry no logprob --
        # presumably because nothing precedes it to condition on.
        assert logprobs.token_logprobs[0] is None
        assert len(logprobs.top_logprobs) > 5
        assert logprobs.top_logprobs[0] is None
        assert len(logprobs.tokens) > 5
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
pytest
.
main
([
__file__
])
pytest
.
main
([
__file__
])
vllm/entrypoints/openai/serving_chat.py
View file @
95e7d4a9
...
@@ -63,8 +63,9 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -63,8 +63,9 @@ class OpenAIServingChat(OpenAIServing):
request_id
=
f
"cmpl-
{
random_uuid
()
}
"
request_id
=
f
"cmpl-
{
random_uuid
()
}
"
try
:
try
:
token_ids
=
self
.
_validate_prompt_and_tokenize
(
request
,
# Tokenize/detokenize depending on prompt format (string/token list)
prompt
=
prompt
)
prompt_ids
,
prompt_text
=
self
.
_validate_prompt_and_tokenize
(
request
,
prompt
=
prompt
)
sampling_params
=
request
.
to_sampling_params
()
sampling_params
=
request
.
to_sampling_params
()
lora_request
=
self
.
_maybe_get_lora
(
request
)
lora_request
=
self
.
_maybe_get_lora
(
request
)
guided_decode_logits_processor
=
(
guided_decode_logits_processor
=
(
...
@@ -78,8 +79,8 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -78,8 +79,8 @@ class OpenAIServingChat(OpenAIServing):
except
ValueError
as
e
:
except
ValueError
as
e
:
return
self
.
create_error_response
(
str
(
e
))
return
self
.
create_error_response
(
str
(
e
))
result_generator
=
self
.
engine
.
generate
(
prompt
,
sampling_params
,
result_generator
=
self
.
engine
.
generate
(
prompt
_text
,
sampling_params
,
request_id
,
token
_ids
,
request_id
,
prompt
_ids
,
lora_request
)
lora_request
)
# Streaming response
# Streaming response
if
request
.
stream
:
if
request
.
stream
:
...
...
vllm/entrypoints/openai/serving_completion.py
View file @
95e7d4a9
...
@@ -136,23 +136,24 @@ class OpenAIServingCompletion(OpenAIServing):
...
@@ -136,23 +136,24 @@ class OpenAIServingCompletion(OpenAIServing):
for
i
,
prompt
in
enumerate
(
prompts
):
for
i
,
prompt
in
enumerate
(
prompts
):
if
prompt_is_tokens
:
if
prompt_is_tokens
:
input_id
s
=
self
.
_validate_prompt_and_tokenize
(
prompt_format
s
=
self
.
_validate_prompt_and_tokenize
(
request
,
request
,
prompt_ids
=
prompt
,
prompt_ids
=
prompt
,
truncate_prompt_tokens
=
sampling_params
.
truncate_prompt_tokens
=
sampling_params
.
truncate_prompt_tokens
)
truncate_prompt_tokens
)
else
:
else
:
input_id
s
=
self
.
_validate_prompt_and_tokenize
(
prompt_format
s
=
self
.
_validate_prompt_and_tokenize
(
request
,
request
,
prompt
=
prompt
,
prompt
=
prompt
,
truncate_prompt_tokens
=
sampling_params
.
truncate_prompt_tokens
=
sampling_params
.
truncate_prompt_tokens
)
truncate_prompt_tokens
)
prompt_ids
,
prompt_text
=
prompt_formats
generators
.
append
(
generators
.
append
(
self
.
engine
.
generate
(
prompt
,
self
.
engine
.
generate
(
prompt
_text
,
sampling_params
,
sampling_params
,
f
"
{
request_id
}
-
{
i
}
"
,
f
"
{
request_id
}
-
{
i
}
"
,
prompt_token_ids
=
inpu
t_ids
,
prompt_token_ids
=
promp
t_ids
,
lora_request
=
lora_request
))
lora_request
=
lora_request
))
except
ValueError
as
e
:
except
ValueError
as
e
:
# TODO: Use a vllm-specific Validation Error
# TODO: Use a vllm-specific Validation Error
...
@@ -326,7 +327,8 @@ class OpenAIServingCompletion(OpenAIServing):
...
@@ -326,7 +327,8 @@ class OpenAIServingCompletion(OpenAIServing):
output_text
=
prompt_text
output_text
=
prompt_text
elif
request
.
echo
and
request
.
max_tokens
>
0
:
elif
request
.
echo
and
request
.
max_tokens
>
0
:
token_ids
=
prompt_token_ids
+
output
.
token_ids
token_ids
=
prompt_token_ids
+
output
.
token_ids
top_logprobs
=
prompt_logprobs
+
output
.
logprobs
top_logprobs
=
(
prompt_logprobs
+
output
.
logprobs
if
request
.
logprobs
else
None
)
output_text
=
prompt_text
+
output
.
text
output_text
=
prompt_text
+
output
.
text
else
:
else
:
token_ids
=
output
.
token_ids
token_ids
=
output
.
token_ids
...
@@ -334,6 +336,9 @@ class OpenAIServingCompletion(OpenAIServing):
...
@@ -334,6 +336,9 @@ class OpenAIServingCompletion(OpenAIServing):
output_text
=
output
.
text
output_text
=
output
.
text
if
request
.
logprobs
is
not
None
:
if
request
.
logprobs
is
not
None
:
assert
top_logprobs
is
not
None
,
(
"top_logprobs must be provided when logprobs "
"is requested"
)
logprobs
=
self
.
_create_logprobs
(
logprobs
=
self
.
_create_logprobs
(
token_ids
=
token_ids
,
token_ids
=
token_ids
,
top_logprobs
=
top_logprobs
,
top_logprobs
=
top_logprobs
,
...
...
vllm/entrypoints/openai/serving_engine.py
View file @
95e7d4a9
...
@@ -2,7 +2,7 @@ import asyncio
...
@@ -2,7 +2,7 @@ import asyncio
import
json
import
json
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
http
import
HTTPStatus
from
http
import
HTTPStatus
from
typing
import
Dict
,
List
,
Optional
,
Union
from
typing
import
Dict
,
List
,
Optional
,
Tuple
,
Union
from
pydantic
import
conint
from
pydantic
import
conint
...
@@ -99,27 +99,32 @@ class OpenAIServing:
...
@@ -99,27 +99,32 @@ class OpenAIServing:
last_token_len
=
0
last_token_len
=
0
if
num_output_top_logprobs
:
if
num_output_top_logprobs
:
logprobs
.
top_logprobs
=
[]
logprobs
.
top_logprobs
=
[]
for
i
,
token_id
in
enumerate
(
token_ids
):
for
i
,
token_id
in
enumerate
(
token_ids
):
step_top_logprobs
=
top_logprobs
[
i
]
step_top_logprobs
=
top_logprobs
[
i
]
if
step_top_logprobs
is
not
None
:
if
step_top_logprobs
is
None
:
token_logprob
=
step_top_logprobs
[
token_id
].
logprob
token
=
self
.
tokenizer
.
decode
(
token_id
)
logprobs
.
tokens
.
append
(
token
)
logprobs
.
token_logprobs
.
append
(
None
)
logprobs
.
top_logprobs
.
append
(
None
)
else
:
else
:
token_logprob
=
None
token_logprob
=
step_top_logprobs
[
token_id
].
logprob
token
=
step_top_logprobs
[
token_id
].
decoded_token
token
=
step_top_logprobs
[
token_id
].
decoded_token
logprobs
.
tokens
.
append
(
token
)
logprobs
.
tokens
.
append
(
token
)
logprobs
.
token_logprobs
.
append
(
token_logprob
)
logprobs
.
token_logprobs
.
append
(
token_logprob
)
if
len
(
logprobs
.
text_offset
)
==
0
:
logprobs
.
text_offset
.
append
(
initial_text_offset
)
else
:
logprobs
.
text_offset
.
append
(
logprobs
.
text_offset
[
-
1
]
+
last_token_len
)
last_token_len
=
len
(
token
)
if
num_output_top_logprobs
:
if
num_output_top_logprobs
:
logprobs
.
top_logprobs
.
append
({
logprobs
.
top_logprobs
.
append
({
p
.
decoded_token
:
p
.
logprob
p
.
decoded_token
:
p
.
logprob
for
i
,
p
in
step_top_logprobs
.
items
()
for
i
,
p
in
step_top_logprobs
.
items
()
}
if
step_top_logprobs
else
None
)
}
if
step_top_logprobs
else
None
)
if
len
(
logprobs
.
text_offset
)
==
0
:
logprobs
.
text_offset
.
append
(
initial_text_offset
)
else
:
logprobs
.
text_offset
.
append
(
logprobs
.
text_offset
[
-
1
]
+
last_token_len
)
last_token_len
=
len
(
token
)
return
logprobs
return
logprobs
def
create_error_response
(
def
create_error_response
(
...
@@ -169,7 +174,7 @@ class OpenAIServing:
...
@@ -169,7 +174,7 @@ class OpenAIServing:
prompt
:
Optional
[
str
]
=
None
,
prompt
:
Optional
[
str
]
=
None
,
prompt_ids
:
Optional
[
List
[
int
]]
=
None
,
prompt_ids
:
Optional
[
List
[
int
]]
=
None
,
truncate_prompt_tokens
:
Optional
[
conint
(
ge
=
1
)]
=
None
truncate_prompt_tokens
:
Optional
[
conint
(
ge
=
1
)]
=
None
)
->
List
[
int
]:
)
->
Tuple
[
List
[
int
]
,
str
]
:
if
not
(
prompt
or
prompt_ids
):
if
not
(
prompt
or
prompt_ids
):
raise
ValueError
(
"Either prompt or prompt_ids should be provided."
)
raise
ValueError
(
"Either prompt or prompt_ids should be provided."
)
if
(
prompt
and
prompt_ids
):
if
(
prompt
and
prompt_ids
):
...
@@ -187,6 +192,8 @@ class OpenAIServing:
...
@@ -187,6 +192,8 @@ class OpenAIServing:
else
:
else
:
input_ids
=
prompt_ids
input_ids
=
prompt_ids
input_text
=
prompt
if
prompt
is
not
None
else
self
.
tokenizer
.
decode
(
prompt_ids
)
token_num
=
len
(
input_ids
)
token_num
=
len
(
input_ids
)
if
request
.
max_tokens
is
None
:
if
request
.
max_tokens
is
None
:
...
@@ -201,4 +208,4 @@ class OpenAIServing:
...
@@ -201,4 +208,4 @@ class OpenAIServing:
f
"
{
request
.
max_tokens
}
in the completion). "
f
"
{
request
.
max_tokens
}
in the completion). "
f
"Please reduce the length of the messages or completion."
,
)
f
"Please reduce the length of the messages or completion."
,
)
else
:
else
:
return
input_ids
return
input_ids
,
input_text
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment