Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
87d41c84
Unverified
Commit
87d41c84
authored
May 30, 2024
by
Breno Faria
Committed by
GitHub
May 30, 2024
Browse files
[BUGFIX] [FRONTEND] Correct chat logprobs (#5029)
Co-authored-by:
Breno Faria
<
breno.faria@intrafind.com
>
parent
e07aff9e
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
361 additions
and
98 deletions
+361
-98
tests/async_engine/test_openapi_server_ray.py
tests/async_engine/test_openapi_server_ray.py
+4
-2
tests/entrypoints/test_openai_server.py
tests/entrypoints/test_openai_server.py
+181
-28
vllm/entrypoints/openai/protocol.py
vllm/entrypoints/openai/protocol.py
+43
-7
vllm/entrypoints/openai/serving_chat.py
vllm/entrypoints/openai/serving_chat.py
+59
-9
vllm/entrypoints/openai/serving_completion.py
vllm/entrypoints/openai/serving_completion.py
+68
-6
vllm/entrypoints/openai/serving_engine.py
vllm/entrypoints/openai/serving_engine.py
+6
-46
No files found.
tests/async_engine/test_openapi_server_ray.py
View file @
87d41c84
...
...
@@ -94,8 +94,10 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI):
chat_completion
.
choices
)
==
1
assert
chat_completion
.
choices
[
0
].
message
is
not
None
assert
chat_completion
.
choices
[
0
].
logprobs
is
not
None
assert
chat_completion
.
choices
[
0
].
logprobs
.
top_logprobs
is
not
None
assert
len
(
chat_completion
.
choices
[
0
].
logprobs
.
top_logprobs
[
0
])
==
5
assert
chat_completion
.
choices
[
0
].
logprobs
.
content
[
0
].
top_logprobs
is
not
None
assert
len
(
chat_completion
.
choices
[
0
].
logprobs
.
content
[
0
].
top_logprobs
)
==
5
message
=
chat_completion
.
choices
[
0
].
message
assert
message
.
content
is
not
None
and
len
(
message
.
content
)
>=
10
assert
message
.
role
==
"assistant"
...
...
tests/entrypoints/test_openai_server.py
View file @
87d41c84
...
...
@@ -184,6 +184,26 @@ async def test_single_completion(server, client: openai.AsyncOpenAI,
completion
.
choices
[
0
].
text
)
>=
5
@pytest.mark.asyncio
@pytest.mark.parametrize(
    # first test base model, then test loras
    "model_name",
    [MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
)
async def test_no_logprobs(server, client: openai.AsyncOpenAI,
                           model_name: str):
    """A completion requested with ``logprobs=None`` carries no logprobs."""
    # test using token IDs
    # Query the parametrized model (not the MODEL_NAME constant) so that the
    # LoRA variants listed above are actually exercised.
    completion = await client.completions.create(
        model=model_name,
        prompt=[0, 0, 0, 0, 0],
        max_tokens=5,
        temperature=0.0,
        logprobs=None,
    )
    choice = completion.choices[0]
    assert choice.logprobs is None
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
parametrize
(
# first test base model, then test loras
...
...
@@ -203,7 +223,72 @@ async def test_zero_logprobs(server, client: openai.AsyncOpenAI,
choice
=
completion
.
choices
[
0
]
assert
choice
.
logprobs
is
not
None
assert
choice
.
logprobs
.
token_logprobs
is
not
None
assert
choice
.
logprobs
.
top_logprobs
is
None
assert
choice
.
logprobs
.
top_logprobs
is
not
None
assert
len
(
choice
.
logprobs
.
top_logprobs
[
0
])
<=
1
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    [MODEL_NAME, "zephyr-lora"],
)
async def test_some_logprobs(server, client: openai.AsyncOpenAI,
                             model_name: str):
    """A completion requested with ``logprobs=5`` returns top logprobs.

    At most 6 entries are accepted per position: the 5 requested
    alternatives plus, possibly, the sampled token itself.
    """
    # test using token IDs
    # Query the parametrized model (not the MODEL_NAME constant) so that the
    # LoRA variant listed above is actually exercised.
    completion = await client.completions.create(
        model=model_name,
        prompt=[0, 0, 0, 0, 0],
        max_tokens=5,
        temperature=0.0,
        logprobs=5,
    )
    choice = completion.choices[0]
    assert choice.logprobs is not None
    assert choice.logprobs.token_logprobs is not None
    assert choice.logprobs.top_logprobs is not None
    assert len(choice.logprobs.top_logprobs[0]) <= 6
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    [MODEL_NAME, "zephyr-lora"],
)
async def test_too_many_completion_logprobs(server,
                                            client: openai.AsyncOpenAI,
                                            model_name: str):
    """Requesting ``logprobs=6`` is rejected, and the server keeps working.

    The rejection is checked for both the non-streaming and the streaming
    code path; every request targets the parametrized ``model_name``.
    """
    with pytest.raises((openai.BadRequestError, openai.APIError)):
        # test using token IDs
        await client.completions.create(
            model=model_name,
            prompt=[0, 0, 0, 0, 0],
            max_tokens=5,
            temperature=0.0,
            logprobs=6,
        )

    with pytest.raises((openai.BadRequestError, openai.APIError)):
        # test using token IDs
        stream = await client.completions.create(
            model=model_name,
            prompt=[0, 0, 0, 0, 0],
            max_tokens=5,
            temperature=0.0,
            logprobs=6,
            stream=True,
        )
        # Drain the stream: the error may only surface while iterating.
        async for chunk in stream:
            ...

    # the server should still work afterwards
    completion = await client.completions.create(
        model=model_name,
        prompt=[0, 0, 0, 0, 0],
        max_tokens=5,
        temperature=0.0,
    )
    text = completion.choices[0].text
    assert text is not None and len(text) >= 0
@
pytest
.
mark
.
asyncio
...
...
@@ -233,8 +318,10 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI,
chat_completion
.
choices
)
==
1
assert
chat_completion
.
choices
[
0
].
message
is
not
None
assert
chat_completion
.
choices
[
0
].
logprobs
is
not
None
assert
chat_completion
.
choices
[
0
].
logprobs
.
top_logprobs
is
not
None
assert
len
(
chat_completion
.
choices
[
0
].
logprobs
.
top_logprobs
[
0
])
==
5
assert
chat_completion
.
choices
[
0
].
logprobs
.
content
[
0
].
top_logprobs
is
not
None
assert
len
(
chat_completion
.
choices
[
0
].
logprobs
.
content
[
0
].
top_logprobs
)
==
5
message
=
chat_completion
.
choices
[
0
].
message
assert
message
.
content
is
not
None
and
len
(
message
.
content
)
>=
10
assert
message
.
role
==
"assistant"
...
...
@@ -251,10 +338,93 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI,
assert
message
.
content
is
not
None
and
len
(
message
.
content
)
>=
0
@pytest.mark.asyncio
@pytest.mark.parametrize(
    # first test base model, then test loras
    "model_name",
    [MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
)
async def test_no_logprobs_chat(server, client: openai.AsyncOpenAI,
                                model_name: str):
    """A chat completion requested with ``logprobs=False`` has no logprobs."""
    system_turn = {"role": "system", "content": "you are a helpful assistant"}
    user_turn = {"role": "user", "content": "what is 1+1?"}

    response = await client.chat.completions.create(
        model=model_name,
        messages=[system_turn, user_turn],
        max_tokens=5,
        temperature=0.0,
        logprobs=False,
    )

    first_choice = response.choices[0]
    assert first_choice.logprobs is None
@pytest.mark.asyncio
@pytest.mark.parametrize(
    # just test 1 lora hereafter
    "model_name",
    [MODEL_NAME, "zephyr-lora"],
)
async def test_zero_logprobs_chat(server, client: openai.AsyncOpenAI,
                                  model_name: str):
    """``logprobs=True`` with ``top_logprobs=0`` still yields logprob content.

    The per-token ``top_logprobs`` list of the first token may hold at most
    one entry.
    """
    system_turn = {"role": "system", "content": "you are a helpful assistant"}
    user_turn = {"role": "user", "content": "what is 1+1?"}

    response = await client.chat.completions.create(
        model=model_name,
        messages=[system_turn, user_turn],
        max_tokens=5,
        temperature=0.0,
        logprobs=True,
        top_logprobs=0,
    )

    first_choice = response.choices[0]
    assert first_choice.logprobs is not None
    assert first_choice.logprobs.content is not None
    first_token = first_choice.logprobs.content[0]
    assert len(first_token.top_logprobs) <= 1
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    [MODEL_NAME, "zephyr-lora"],
)
async def test_some_logprobs_chat(server, client: openai.AsyncOpenAI,
                                  model_name: str):
    """``logprobs=True`` with ``top_logprobs=5`` yields per-token top logprobs.

    The first token's ``top_logprobs`` list may hold at most six entries.
    """
    system_turn = {"role": "system", "content": "you are a helpful assistant"}
    user_turn = {"role": "user", "content": "what is 1+1?"}

    response = await client.chat.completions.create(
        model=model_name,
        messages=[system_turn, user_turn],
        max_tokens=5,
        temperature=0.0,
        logprobs=True,
        top_logprobs=5,
    )

    first_choice = response.choices[0]
    assert first_choice.logprobs is not None
    assert first_choice.logprobs.content is not None
    first_token = first_choice.logprobs.content[0]
    assert len(first_token.top_logprobs) <= 6
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
parametrize
(
"model_name"
,
[
MODEL_NAME
])
async
def
test_too_many_logprobs
(
server
,
client
:
openai
.
AsyncOpenAI
,
model_name
:
str
):
async
def
test_too_many_
chat_
logprobs
(
server
,
client
:
openai
.
AsyncOpenAI
,
model_name
:
str
):
messages
=
[{
"role"
:
"system"
,
"content"
:
"you are a helpful assistant"
...
...
@@ -263,13 +433,13 @@ async def test_too_many_logprobs(server, client: openai.AsyncOpenAI,
"content"
:
"what is 1+1?"
}]
# Default max_logprobs is
5
, so this should raise an error
# Default max_logprobs is
20
, so this should raise an error
with
pytest
.
raises
((
openai
.
BadRequestError
,
openai
.
APIError
)):
stream
=
await
client
.
chat
.
completions
.
create
(
model
=
model_name
,
messages
=
messages
,
max_tokens
=
10
,
logprobs
=
True
,
top_logprobs
=
1
0
,
top_logprobs
=
2
1
,
stream
=
True
)
async
for
chunk
in
stream
:
...
...
...
@@ -279,25 +449,9 @@ async def test_too_many_logprobs(server, client: openai.AsyncOpenAI,
messages
=
messages
,
max_tokens
=
10
,
logprobs
=
True
,
top_logprobs
=
1
0
,
top_logprobs
=
3
0
,
stream
=
False
)
with
pytest
.
raises
((
openai
.
BadRequestError
,
openai
.
APIError
)):
stream
=
await
client
.
completions
.
create
(
model
=
model_name
,
prompt
=
"Test"
,
max_tokens
=
10
,
logprobs
=
10
,
stream
=
True
)
async
for
chunk
in
stream
:
...
with
pytest
.
raises
(
openai
.
BadRequestError
):
await
client
.
completions
.
create
(
model
=
model_name
,
prompt
=
"Test"
,
max_tokens
=
10
,
logprobs
=
10
,
stream
=
False
)
# the server should still work afterwards
chat_completion
=
await
client
.
chat
.
completions
.
create
(
model
=
model_name
,
messages
=
messages
,
...
...
@@ -744,13 +898,12 @@ async def test_guided_choice_chat_logprobs(server, client: openai.AsyncOpenAI,
top_logprobs
=
5
,
extra_body
=
dict
(
guided_choice
=
TEST_CHOICE
,
guided_decoding_backend
=
guided_decoding_backend
))
top_logprobs
=
chat_completion
.
choices
[
0
].
logprobs
.
top_logprobs
top_logprobs
=
chat_completion
.
choices
[
0
].
logprobs
.
content
[
0
].
top_logprobs
# -9999.0 is the minimum logprob returned by OpenAI
assert
all
(
isinstance
(
logprob
,
float
)
and
logprob
>=
-
9999.0
for
token_dict
in
top_logprobs
for
token
,
logprob
in
token_dict
.
items
())
isinstance
(
token
.
logprob
,
float
)
and
token
.
logprob
>=
-
9999.0
for
token
in
top_logprobs
)
@
pytest
.
mark
.
asyncio
...
...
vllm/entrypoints/openai/protocol.py
View file @
87d41c84
...
...
@@ -250,6 +250,19 @@ class ChatCompletionRequest(OpenAIBaseModel):
"('guided_json', 'guided_regex' or 'guided_choice')."
)
return
data
@model_validator(mode="before")
@classmethod
def check_logprobs(cls, data):
    """Validate the ``logprobs`` / ``top_logprobs`` pair of a chat request.

    Runs before field validation, on the raw request mapping.

    Raises:
        ValueError: if ``top_logprobs`` is given while ``logprobs`` is
            absent or falsy, or if ``top_logprobs`` lies outside [0, 20].
    """
    top_logprobs = data.get("top_logprobs")
    if top_logprobs is not None:
        # Reject every falsy `logprobs` (missing, None, False), not only the
        # literal False: otherwise `top_logprobs` would be accepted here and
        # then silently ignored because logprobs are not enabled.
        if not data.get("logprobs"):
            raise ValueError(
                "when using `top_logprobs`, `logprobs` must be set to true."
            )
        if not 0 <= top_logprobs <= 20:
            raise ValueError(
                "`top_logprobs` must be a value in the interval [0, 20]."
            )
    return data
class
CompletionRequest
(
OpenAIBaseModel
):
# Ordered by official OpenAI API documentation
...
...
@@ -396,6 +409,15 @@ class CompletionRequest(OpenAIBaseModel):
"('guided_json', 'guided_regex' or 'guided_choice')."
)
return
data
@model_validator(mode="before")
@classmethod
def check_logprobs(cls, data):
    """Validate the ``logprobs`` count of a completion request.

    Runs before field validation, on the raw request mapping.

    Raises:
        ValueError: if ``logprobs`` is given and lies outside [0, 5].
    """
    logprobs = data.get("logprobs")
    if logprobs is not None and not 0 <= logprobs <= 5:
        # Adjacent literals are concatenated into ONE message string; a comma
        # between them would hand ValueError a tuple instead of a message.
        raise ValueError(
            "if passed, `logprobs` must be a value"
            " in the interval [0, 5]."
        )
    return data
class
EmbeddingRequest
(
BaseModel
):
# Ordered by official OpenAI API documentation
...
...
@@ -415,7 +437,7 @@ class EmbeddingRequest(BaseModel):
return
PoolingParams
(
additional_data
=
self
.
additional_data
)
class
LogProbs
(
OpenAIBaseModel
):
class
Completion
LogProbs
(
OpenAIBaseModel
):
text_offset
:
List
[
int
]
=
Field
(
default_factory
=
list
)
token_logprobs
:
List
[
Optional
[
float
]]
=
Field
(
default_factory
=
list
)
tokens
:
List
[
str
]
=
Field
(
default_factory
=
list
)
...
...
@@ -425,7 +447,7 @@ class LogProbs(OpenAIBaseModel):
class
CompletionResponseChoice
(
OpenAIBaseModel
):
index
:
int
text
:
str
logprobs
:
Optional
[
LogProbs
]
=
None
logprobs
:
Optional
[
Completion
LogProbs
]
=
None
finish_reason
:
Optional
[
str
]
=
None
stop_reason
:
Optional
[
Union
[
int
,
str
]]
=
Field
(
default
=
None
,
...
...
@@ -448,7 +470,7 @@ class CompletionResponse(OpenAIBaseModel):
class
CompletionResponseStreamChoice
(
OpenAIBaseModel
):
index
:
int
text
:
str
logprobs
:
Optional
[
LogProbs
]
=
None
logprobs
:
Optional
[
Completion
LogProbs
]
=
None
finish_reason
:
Optional
[
str
]
=
None
stop_reason
:
Optional
[
Union
[
int
,
str
]]
=
Field
(
default
=
None
,
...
...
@@ -488,11 +510,25 @@ class ChatMessage(OpenAIBaseModel):
content
:
str
class ChatCompletionLogProb(OpenAIBaseModel):
    """Log probability of one token in a chat completion response."""

    # Text of the token.
    token: str
    # Log probability of the token; defaults to -9999.0, the same floor the
    # serving code clamps logprobs to.
    logprob: float = -9999.0
    # Optional list of integer byte values for the token; the serving code
    # fills it from the token's UTF-8 encoding.
    bytes: Optional[List[int]] = None
class ChatCompletionLogProbsContent(ChatCompletionLogProb):
    """Logprob entry for one output token plus its candidate alternatives."""

    # Candidate tokens for this position; empty list by default.
    top_logprobs: List[ChatCompletionLogProb] = Field(default_factory=list)
class ChatCompletionLogProbs(OpenAIBaseModel):
    """Container for the logprob entries of a chat completion choice."""

    # One entry per output token, or None when not populated.
    content: Optional[List[ChatCompletionLogProbsContent]] = None
class
ChatCompletionResponseChoice
(
OpenAIBaseModel
):
index
:
int
message
:
ChatMessage
logprobs
:
Optional
[
LogProbs
]
=
None
finish_reason
:
Optional
[
str
]
=
None
logprobs
:
Optional
[
ChatCompletion
LogProbs
]
=
None
finish_reason
:
Optional
[
Literal
[
"stop"
,
"length"
,
"tool_calls"
]
]
=
None
stop_reason
:
Optional
[
Union
[
int
,
str
]]
=
None
...
...
@@ -513,8 +549,8 @@ class DeltaMessage(OpenAIBaseModel):
class
ChatCompletionResponseStreamChoice
(
OpenAIBaseModel
):
index
:
int
delta
:
DeltaMessage
logprobs
:
Optional
[
LogProbs
]
=
None
finish_reason
:
Optional
[
str
]
=
None
logprobs
:
Optional
[
ChatCompletion
LogProbs
]
=
None
finish_reason
:
Optional
[
Literal
[
"stop"
,
"length"
,
"tool_calls"
]
]
=
None
stop_reason
:
Optional
[
Union
[
int
,
str
]]
=
None
...
...
vllm/entrypoints/openai/serving_chat.py
View file @
87d41c84
import
codecs
import
time
from
dataclasses
import
dataclass
from
typing
import
(
AsyncGenerator
,
AsyncIterator
,
Iterable
,
List
,
Optional
,
TypedDict
,
Union
,
cast
,
final
)
from
typing
import
(
AsyncGenerator
,
AsyncIterator
,
Dict
,
Iterable
,
List
,
Optional
)
from
typing
import
Sequence
as
GenericSequence
from
typing
import
TypedDict
,
Union
,
cast
,
final
from
fastapi
import
Request
from
openai.types.chat
import
ChatCompletionContentPartTextParam
...
...
@@ -10,8 +12,9 @@ from openai.types.chat import ChatCompletionContentPartTextParam
from
vllm.config
import
ModelConfig
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
from
vllm.entrypoints.openai.protocol
import
(
ChatCompletionContentPartParam
,
ChatCompletionMessageParam
,
ChatCompletionRequest
,
ChatCompletionResponse
,
ChatCompletionContentPartParam
,
ChatCompletionLogProb
,
ChatCompletionLogProbs
,
ChatCompletionLogProbsContent
,
ChatCompletionMessageParam
,
ChatCompletionRequest
,
ChatCompletionResponse
,
ChatCompletionResponseChoice
,
ChatCompletionResponseStreamChoice
,
ChatCompletionStreamResponse
,
ChatMessage
,
DeltaMessage
,
ErrorResponse
,
UsageInfo
)
...
...
@@ -21,6 +24,7 @@ from vllm.logger import init_logger
from
vllm.model_executor.guided_decoding
import
(
get_guided_decoding_logits_processor
)
from
vllm.outputs
import
RequestOutput
from
vllm.sequence
import
Logprob
from
vllm.utils
import
random_uuid
logger
=
init_logger
(
__name__
)
...
...
@@ -283,11 +287,10 @@ class OpenAIServingChat(OpenAIServing):
previous_num_tokens
[
i
]:]
if
output
.
logprobs
else
None
if
request
.
logprobs
:
logprobs
=
self
.
_create_logprobs
(
logprobs
=
self
.
_create_
chat_
logprobs
(
token_ids
=
delta_token_ids
,
top_logprobs
=
top_logprobs
,
num_output_top_logprobs
=
request
.
top_logprobs
,
initial_text_offset
=
len
(
previous_texts
[
i
]),
)
else
:
logprobs
=
None
...
...
@@ -370,7 +373,7 @@ class OpenAIServingChat(OpenAIServing):
top_logprobs
=
output
.
logprobs
if
request
.
logprobs
:
logprobs
=
self
.
_create_logprobs
(
logprobs
=
self
.
_create_
chat_
logprobs
(
token_ids
=
token_ids
,
top_logprobs
=
top_logprobs
,
num_output_top_logprobs
=
request
.
top_logprobs
,
...
...
@@ -383,8 +386,7 @@ class OpenAIServingChat(OpenAIServing):
message
=
ChatMessage
(
role
=
role
,
content
=
output
.
text
),
logprobs
=
logprobs
,
finish_reason
=
output
.
finish_reason
,
stop_reason
=
output
.
stop_reason
,
)
stop_reason
=
output
.
stop_reason
)
choices
.
append
(
choice_data
)
if
request
.
echo
:
...
...
@@ -414,3 +416,51 @@ class OpenAIServingChat(OpenAIServing):
)
return
response
def _get_top_logprobs(
        self, logprobs: Dict[int, Logprob],
        top_logprobs: Optional[int]) -> List[ChatCompletionLogProb]:
    """Convert the first ``top_logprobs`` candidates into response entries.

    Args:
        logprobs: candidate logprobs for one position, keyed by token id;
            entries are taken in the mapping's iteration order.
        top_logprobs: maximum number of entries to return. ``None`` or ``0``
            yields an empty list.

    Returns:
        One ``ChatCompletionLogProb`` per selected candidate, with the
        logprob clamped to a minimum of -9999.0.
    """
    if not top_logprobs:
        return []

    entries: List[ChatCompletionLogProb] = []
    for rank, (token_id, logprob) in enumerate(logprobs.items()):
        if rank >= top_logprobs:
            # Enough candidates collected; skip the rest of the mapping.
            break
        # Decode once and reuse the text for both `token` and `bytes`.
        token = self._get_decoded_token(logprob, token_id)
        entries.append(
            ChatCompletionLogProb(
                token=token,
                logprob=max(logprob.logprob, -9999.0),
                bytes=list(token.encode("utf-8", errors="replace"))))
    return entries
def _create_chat_logprobs(
    self,
    token_ids: GenericSequence[int],
    top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
    num_output_top_logprobs: Optional[int] = None,
) -> ChatCompletionLogProbs:
    """Create OpenAI-style logprobs.

    Args:
        token_ids: ids of the generated tokens.
        top_logprobs: per-position candidate logprobs keyed by token id,
            aligned with ``token_ids``; ``None`` where none are available.
        num_output_top_logprobs: how many candidates to report per position.

    Returns:
        A ``ChatCompletionLogProbs`` holding one content entry per token.
    """
    logprobs_content = []

    for i, token_id in enumerate(token_ids):
        step_top_logprobs = top_logprobs[i]
        if step_top_logprobs is None:
            # No logprob data for this position: report the decoded text
            # only (decoded once, reused for `token` and `bytes`).
            token = self.tokenizer.decode(token_id)
            logprobs_content.append(
                ChatCompletionLogProbsContent(
                    token=token,
                    bytes=list(token.encode("utf-8", errors="replace"))))
        else:
            sampled = step_top_logprobs[token_id]
            # `decoded_token` may be None; go through the shared helper,
            # which falls back to the tokenizer, instead of calling
            # `.encode` on a possibly-None attribute.
            token = self._get_decoded_token(sampled, token_id)
            logprobs_content.append(
                ChatCompletionLogProbsContent(
                    token=token,
                    logprob=max(sampled.logprob, -9999.0),
                    bytes=list(token.encode("utf-8", errors="replace")),
                    top_logprobs=self._get_top_logprobs(
                        step_top_logprobs, num_output_top_logprobs)))

    return ChatCompletionLogProbs(content=logprobs_content)
vllm/entrypoints/openai/serving_completion.py
View file @
87d41c84
import
time
from
typing
import
(
AsyncGenerator
,
AsyncIterator
,
Callable
,
Dict
,
List
,
Optional
,
Tuple
)
Optional
)
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Tuple
from
fastapi
import
Request
from
vllm.config
import
ModelConfig
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
from
vllm.entrypoints.openai.protocol
import
(
CompletionRequest
,
# yapf: disable
from
vllm.entrypoints.openai.protocol
import
(
CompletionLogProbs
,
CompletionRequest
,
CompletionResponse
,
CompletionResponseChoice
,
CompletionResponseStreamChoice
,
CompletionStreamResponse
,
LogProbs
,
UsageInfo
)
UsageInfo
)
# yapf: enable
from
vllm.entrypoints.openai.serving_engine
import
(
LoRAModulePath
,
OpenAIServing
)
from
vllm.logger
import
init_logger
from
vllm.model_executor.guided_decoding
import
(
get_guided_decoding_logits_processor
)
from
vllm.outputs
import
RequestOutput
from
vllm.sequence
import
Logprob
from
vllm.utils
import
merge_async_iterators
,
random_uuid
logger
=
init_logger
(
__name__
)
...
...
@@ -25,7 +31,7 @@ logger = init_logger(__name__)
TypeTokenIDs
=
List
[
int
]
TypeTopLogProbs
=
List
[
Optional
[
Dict
[
int
,
float
]]]
TypeCreateLogProbsFn
=
Callable
[
[
TypeTokenIDs
,
TypeTopLogProbs
,
Optional
[
int
],
int
],
LogProbs
]
[
TypeTokenIDs
,
TypeTopLogProbs
,
Optional
[
int
],
int
],
Completion
LogProbs
]
def
parse_prompt_format
(
prompt
)
->
Tuple
[
bool
,
list
]:
...
...
@@ -235,7 +241,7 @@ class OpenAIServingCompletion(OpenAIServing):
i
]:]
if
output
.
logprobs
else
None
if
request
.
logprobs
is
not
None
:
logprobs
=
self
.
_create_logprobs
(
logprobs
=
self
.
_create_
completion_
logprobs
(
token_ids
=
delta_token_ids
,
top_logprobs
=
top_logprobs
,
num_output_top_logprobs
=
request
.
logprobs
,
...
...
@@ -317,7 +323,7 @@ class OpenAIServingCompletion(OpenAIServing):
assert
top_logprobs
is
not
None
,
(
"top_logprobs must be provided when logprobs "
"is requested"
)
logprobs
=
self
.
_create_logprobs
(
logprobs
=
self
.
_create_
completion_
logprobs
(
token_ids
=
token_ids
,
top_logprobs
=
top_logprobs
,
num_output_top_logprobs
=
request
.
logprobs
,
...
...
@@ -351,3 +357,59 @@ class OpenAIServingCompletion(OpenAIServing):
choices
=
choices
,
usage
=
usage
,
)
def _create_completion_logprobs(
    self,
    token_ids: GenericSequence[int],
    top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
    num_output_top_logprobs: int,
    initial_text_offset: int = 0,
) -> CompletionLogProbs:
    """Assemble the Completion-API logprobs payload for a token sequence.

    Args:
        token_ids: ids of the tokens to report.
        top_logprobs: per-position candidate logprobs keyed by token id,
            aligned with ``token_ids``; ``None`` where none are available.
        num_output_top_logprobs: requested number of alternatives; up to
            ``num_output_top_logprobs + 1`` candidates are kept per position.
        initial_text_offset: text offset assigned to the first token.

    Returns:
        A ``CompletionLogProbs`` with parallel lists of offsets, token
        logprobs, token texts and per-position candidate maps.
    """
    tokens: List[str] = []
    token_logprobs: List[Optional[float]] = []
    candidate_maps: List[Optional[Dict[str, float]]] = []
    text_offsets: List[int] = []

    # Offset of the token about to be recorded; advanced by each token's
    # text length after it is recorded.
    next_offset = initial_text_offset

    for position, token_id in enumerate(token_ids):
        candidates = top_logprobs[position]

        if candidates is None:
            token = self.tokenizer.decode(token_id)
            sampled_logprob = None
            step_map = None
        else:
            token = self._get_decoded_token(candidates[token_id], token_id)
            sampled_logprob = max(candidates[token_id].logprob, -9999.0)

            # makes sure to add the top num_output_top_logprobs + 1
            # logprobs, as defined in the openai API
            # (cf. https://github.com/openai/openai-openapi/blob/
            # 893ba52242dbd5387a97b96444ee1c742cfce9bd/openapi.yaml#L7153)
            step_map = {}
            for rank, (cand_id, cand) in enumerate(candidates.items()):
                if rank > num_output_top_logprobs:
                    break
                # Convert float("-inf") to the
                # JSON-serializable float that OpenAI uses
                decoded = self._get_decoded_token(cand, cand_id)
                step_map[decoded] = max(cand.logprob, -9999.0)

        tokens.append(token)
        token_logprobs.append(sampled_logprob)
        candidate_maps.append(step_map)

        text_offsets.append(next_offset)
        next_offset += len(token)

    return CompletionLogProbs(
        text_offset=text_offsets,
        token_logprobs=token_logprobs,
        tokens=tokens,
        top_logprobs=candidate_maps,
    )
vllm/entrypoints/openai/serving_engine.py
View file @
87d41c84
...
...
@@ -11,7 +11,7 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine
from
vllm.entrypoints.openai.protocol
import
(
ChatCompletionRequest
,
CompletionRequest
,
EmbeddingRequest
,
ErrorResponse
,
LogProbs
,
ModelCard
,
ModelList
,
ModelCard
,
ModelList
,
ModelPermission
)
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
...
...
@@ -75,51 +75,6 @@ class OpenAIServing:
model_cards
.
extend
(
lora_cards
)
return
ModelList
(
data
=
model_cards
)
def _create_logprobs(
    self,
    token_ids: List[int],
    top_logprobs: List[Optional[Dict[int, Logprob]]],
    num_output_top_logprobs: Optional[int] = None,
    initial_text_offset: int = 0,
) -> LogProbs:
    """Create OpenAI-style logprobs.

    Args:
        token_ids: ids of the tokens to report.
        top_logprobs: per-position candidate logprobs keyed by token id,
            aligned with ``token_ids``; ``None`` where none are available.
        num_output_top_logprobs: when truthy, per-position candidate maps
            are collected in ``LogProbs.top_logprobs``.
        initial_text_offset: text offset assigned to the first token.

    Returns:
        The populated ``LogProbs`` object.
    """
    logprobs = LogProbs()
    last_token_len = 0
    if num_output_top_logprobs:
        logprobs.top_logprobs = []

    for i, token_id in enumerate(token_ids):
        step_top_logprobs = top_logprobs[i]
        if step_top_logprobs is None:
            token = self.tokenizer.decode(token_id)
            logprobs.tokens.append(token)
            logprobs.token_logprobs.append(None)
            # Previously a bare `assert`, which aborted the request when no
            # top-logprobs list had been set up.
            # NOTE(review): assumes LogProbs.top_logprobs may be None when
            # not initialised above -- confirm against the model definition.
            if logprobs.top_logprobs is not None:
                logprobs.top_logprobs.append(None)
        else:
            token_logprob = step_top_logprobs[token_id].logprob
            token = step_top_logprobs[token_id].decoded_token
            logprobs.tokens.append(token)
            token_logprob = max(token_logprob, -9999.0)
            logprobs.token_logprobs.append(token_logprob)

            if num_output_top_logprobs:
                assert logprobs.top_logprobs is not None
                logprobs.top_logprobs.append({
                    # Convert float("-inf") to the
                    # JSON-serializable float that OpenAI uses
                    p.decoded_token: max(p.logprob, -9999.0)
                    for p in step_top_logprobs.values()
                } if step_top_logprobs else None)

        if len(logprobs.text_offset) == 0:
            logprobs.text_offset.append(initial_text_offset)
        else:
            logprobs.text_offset.append(logprobs.text_offset[-1] +
                                        last_token_len)
        last_token_len = len(token)
    return logprobs
def
create_error_response
(
self
,
message
:
str
,
...
...
@@ -235,3 +190,8 @@ class OpenAIServing:
f
"Please reduce the length of the messages or completion."
,
)
else
:
return
input_ids
,
input_text
def
_get_decoded_token
(
self
,
logprob
:
Logprob
,
token_id
:
int
)
->
str
:
if
logprob
.
decoded_token
is
not
None
:
return
logprob
.
decoded_token
return
self
.
tokenizer
.
decode
(
token_id
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment