xdb4_94051 / vllm · Commits · dd7e8f5f

Unverified commit dd7e8f5f, authored Jan 18, 2024 by Simon Mo, committed by GitHub on Jan 18, 2024.
refactor completion api for readability (#2499)
Parent: d2a68364
Showing 6 changed files with 284 additions and 253 deletions.
  tests/entrypoints/test_openai_server.py         +10   -0
  vllm/entrypoints/openai/protocol.py             +45   -0
  vllm/entrypoints/openai/serving_chat.py          +3   -25
  vllm/entrypoints/openai/serving_completion.py   +210  -215
  vllm/entrypoints/openai/serving_engine.py       +15   -12
  vllm/model_executor/weight_utils.py             +1    -1
tests/entrypoints/test_openai_server.py  (View file @ dd7e8f5f)

@@ -88,6 +88,16 @@ async def test_single_completion(server, client: openai.AsyncOpenAI):
     assert completion.usage == openai.types.CompletionUsage(
         completion_tokens=5, prompt_tokens=6, total_tokens=11)
 
+    # test using token IDs
+    completion = await client.completions.create(
+        model=MODEL_NAME,
+        prompt=[0, 0, 0, 0, 0],
+        max_tokens=5,
+        temperature=0.0,
+    )
+    assert completion.choices[0].text is not None and len(
+        completion.choices[0].text) >= 5
+
 
 async def test_single_chat_session(server, client: openai.AsyncOpenAI):
     messages = [{
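The new assertion can also be exercised by hand against a running server. The sketch below is an illustration rather than part of the commit: it assumes a vLLM OpenAI-compatible server is already listening at http://localhost:8000/v1 and that it serves facebook/opt-125m; the api_key value is a placeholder.

import asyncio

import openai

# base_url, api_key and model below are assumed values, not taken from the commit.
client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")


async def main() -> None:
    completion = await client.completions.create(
        model="facebook/opt-125m",   # whichever model the server was started with
        prompt=[0, 0, 0, 0, 0],      # prompt passed as token IDs, as in the test
        max_tokens=5,
        temperature=0.0,
    )
    print(completion.choices[0].text)


asyncio.run(main())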
vllm/entrypoints/openai/protocol.py  (View file @ dd7e8f5f)

@@ -6,6 +6,7 @@ from typing import Dict, List, Literal, Optional, Union
 from pydantic import BaseModel, Field
 
 from vllm.utils import random_uuid
+from vllm.sampling_params import SamplingParams
 
 
 class ErrorResponse(BaseModel):

@@ -78,6 +79,26 @@ class ChatCompletionRequest(BaseModel):
     repetition_penalty: Optional[float] = 1.0
     min_p: Optional[float] = 0.0
 
+    def to_sampling_params(self) -> SamplingParams:
+        return SamplingParams(
+            n=self.n,
+            presence_penalty=self.presence_penalty,
+            frequency_penalty=self.frequency_penalty,
+            repetition_penalty=self.repetition_penalty,
+            temperature=self.temperature,
+            top_p=self.top_p,
+            min_p=self.min_p,
+            stop=self.stop,
+            stop_token_ids=self.stop_token_ids,
+            max_tokens=self.max_tokens,
+            best_of=self.best_of,
+            top_k=self.top_k,
+            ignore_eos=self.ignore_eos,
+            use_beam_search=self.use_beam_search,
+            skip_special_tokens=self.skip_special_tokens,
+            spaces_between_special_tokens=self.spaces_between_special_tokens,
+        )
+
 
 class CompletionRequest(BaseModel):
     model: str

@@ -107,6 +128,30 @@ class CompletionRequest(BaseModel):
     repetition_penalty: Optional[float] = 1.0
     min_p: Optional[float] = 0.0
 
+    def to_sampling_params(self):
+        echo_without_generation = self.echo and self.max_tokens == 0
+
+        return SamplingParams(
+            n=self.n,
+            best_of=self.best_of,
+            presence_penalty=self.presence_penalty,
+            frequency_penalty=self.frequency_penalty,
+            repetition_penalty=self.repetition_penalty,
+            temperature=self.temperature,
+            top_p=self.top_p,
+            top_k=self.top_k,
+            min_p=self.min_p,
+            stop=self.stop,
+            stop_token_ids=self.stop_token_ids,
+            ignore_eos=self.ignore_eos,
+            max_tokens=self.max_tokens if not echo_without_generation else 1,
+            logprobs=self.logprobs,
+            use_beam_search=self.use_beam_search,
+            prompt_logprobs=self.logprobs if self.echo else None,
+            skip_special_tokens=self.skip_special_tokens,
+            spaces_between_special_tokens=(self.spaces_between_special_tokens),
+        )
+
 
 class LogProbs(BaseModel):
     text_offset: List[int] = Field(default_factory=list)
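The two to_sampling_params helpers centralize the request-to-SamplingParams mapping that the serving handlers previously rebuilt by hand (see serving_chat.py and serving_completion.py below). A minimal sketch of how a request converts itself, assuming vLLM is installed; the field values and model name are made up and every unspecified sampling field keeps its pydantic default:

from vllm.entrypoints.openai.protocol import CompletionRequest

# Hypothetical request: only model, prompt, max_tokens and temperature are set.
request = CompletionRequest(model="facebook/opt-125m",
                            prompt="Hello, my name is",
                            max_tokens=32,
                            temperature=0.7)

# Returns a vllm.sampling_params.SamplingParams built from the request fields.
sampling_params = request.to_sampling_params()
print(sampling_params.n, sampling_params.temperature, sampling_params.max_tokens)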
vllm/entrypoints/openai/serving_chat.py  (View file @ dd7e8f5f)

@@ -11,7 +11,6 @@ from vllm.entrypoints.openai.protocol import (
     ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
     UsageInfo)
 from vllm.outputs import RequestOutput
-from vllm.sampling_params import SamplingParams
 from vllm.entrypoints.openai.serving_engine import OpenAIServing
 
 logger = init_logger(__name__)

@@ -60,32 +59,11 @@ class OpenAIServingChat(OpenAIServing):
                 f"Error in applying chat template from request: {str(e)}")
             return self.create_error_response(str(e))
 
-        token_ids, error_check_ret = await self._check_length(request,
-                                                              prompt=prompt)
-        if error_check_ret is not None:
-            return error_check_ret
-
         request_id = f"cmpl-{random_uuid()}"
         try:
-            spaces_between_special_tokens = request.spaces_between_special_tokens
-            sampling_params = SamplingParams(
-                n=request.n,
-                presence_penalty=request.presence_penalty,
-                frequency_penalty=request.frequency_penalty,
-                repetition_penalty=request.repetition_penalty,
-                temperature=request.temperature,
-                top_p=request.top_p,
-                min_p=request.min_p,
-                stop=request.stop,
-                stop_token_ids=request.stop_token_ids,
-                max_tokens=request.max_tokens,
-                best_of=request.best_of,
-                top_k=request.top_k,
-                ignore_eos=request.ignore_eos,
-                use_beam_search=request.use_beam_search,
-                skip_special_tokens=request.skip_special_tokens,
-                spaces_between_special_tokens=spaces_between_special_tokens,
-            )
+            token_ids = self._validate_prompt_and_tokenize(request,
+                                                           prompt=prompt)
+            sampling_params = request.to_sampling_params()
         except ValueError as e:
             return self.create_error_response(str(e))
vllm/entrypoints/openai/serving_completion.py  (View file @ dd7e8f5f)

 import time
 from fastapi import Request
-from typing import AsyncGenerator, Optional
+from typing import AsyncGenerator, AsyncIterator
 from vllm.logger import init_logger
 from vllm.utils import random_uuid
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from .protocol import (CompletionRequest, CompletionResponse,
                        CompletionResponseChoice,
                        CompletionResponseStreamChoice,
                        CompletionStreamResponse, LogProbs, UsageInfo)
 from vllm.outputs import RequestOutput
-from vllm.sampling_params import SamplingParams
 from vllm.entrypoints.openai.serving_engine import OpenAIServing
 
 logger = init_logger(__name__)
 
 
-class OpenAIServingCompletion(OpenAIServing):
-
-    def __init__(self, engine: AsyncLLMEngine, served_model: str):
-        super().__init__(engine=engine, served_model=served_model)
-
-    async def create_completion(self, request: CompletionRequest,
-                                raw_request: Request):
-        """Completion API similar to OpenAI's API.
-
-        See https://platform.openai.com/docs/api-reference/completions/create
-        for the API specification. This API mimics the OpenAI Completion API.
-
-        NOTE: Currently we do not support the following features:
-            - suffix (the language models we currently support do not support
-              suffix)
-            - logit_bias (to be supported by vLLM engine)
-        """
-        error_check_ret = await self._check_model(request)
-        if error_check_ret is not None:
-            return error_check_ret
-
-        # OpenAI API supports echoing the prompt when max_tokens is 0.
-        echo_without_generation = request.echo and request.max_tokens == 0
-
-        if request.suffix is not None:
-            # The language models we currently support do not support suffix.
-            return self.create_error_response(
-                "suffix is not currently supported")
-
-        if request.logit_bias is not None and len(request.logit_bias) > 0:
-            # TODO: support logit_bias in vLLM engine.
-            return self.create_error_response(
-                "logit_bias is not currently supported")
-
-        model_name = request.model
-        request_id = f"cmpl-{random_uuid()}"
-
-        use_token_ids = False
-        if isinstance(request.prompt, list):
-            if len(request.prompt) == 0:
-                return self.create_error_response(
-                    "please provide at least one prompt")
-            first_element = request.prompt[0]
-            if isinstance(first_element, int):
-                use_token_ids = True
-                prompt = request.prompt
-            elif isinstance(first_element, (str, list)):
-                # TODO: handles multiple prompt case in list[list[int]]
-                if len(request.prompt) > 1:
-                    return self.create_error_response(
-                        "multiple prompts in a batch is not currently supported"
-                    )
-                use_token_ids = not isinstance(first_element, str)
-                prompt = request.prompt[0]
-        else:
-            prompt = request.prompt
-
-        if use_token_ids:
-            _, error_check_ret = await self._check_length(request,
-                                                          prompt_ids=prompt)
-        else:
-            token_ids, error_check_ret = await self._check_length(
-                request, prompt=prompt)
-        if error_check_ret is not None:
-            return error_check_ret
-
-        created_time = int(time.monotonic())
-        try:
-            spaces_between_special_tokens = request.spaces_between_special_tokens
-            sampling_params = SamplingParams(
-                n=request.n,
-                best_of=request.best_of,
-                presence_penalty=request.presence_penalty,
-                frequency_penalty=request.frequency_penalty,
-                repetition_penalty=request.repetition_penalty,
-                temperature=request.temperature,
-                top_p=request.top_p,
-                top_k=request.top_k,
-                min_p=request.min_p,
-                stop=request.stop,
-                stop_token_ids=request.stop_token_ids,
-                ignore_eos=request.ignore_eos,
-                max_tokens=request.max_tokens
-                if not echo_without_generation else 1,
-                logprobs=request.logprobs,
-                use_beam_search=request.use_beam_search,
-                prompt_logprobs=request.logprobs if request.echo else None,
-                skip_special_tokens=request.skip_special_tokens,
-                spaces_between_special_tokens=spaces_between_special_tokens,
-            )
-        except ValueError as e:
-            return self.create_error_response(str(e))
-
-        if use_token_ids:
-            result_generator = self.engine.generate(None,
-                                                    sampling_params,
-                                                    request_id,
-                                                    prompt_token_ids=prompt)
-        else:
-            result_generator = self.engine.generate(prompt, sampling_params,
-                                                    request_id, token_ids)
-
-        # Similar to the OpenAI API, when n != best_of, we do not stream the
-        # results. In addition, we do not stream the results when use beam search.
-        stream = (request.stream
-                  and (request.best_of is None or request.n == request.best_of)
-                  and not request.use_beam_search)
-
-        def create_stream_response_json(
-            index: int,
-            text: str,
-            logprobs: Optional[LogProbs] = None,
-            finish_reason: Optional[str] = None,
-            usage: Optional[UsageInfo] = None,
-        ) -> str:
-            choice_data = CompletionResponseStreamChoice(
-                index=index,
-                text=text,
-                logprobs=logprobs,
-                finish_reason=finish_reason,
-            )
-            response = CompletionStreamResponse(
-                id=request_id,
-                created=created_time,
-                model=model_name,
-                choices=[choice_data],
-            )
-            if usage is not None:
-                response.usage = usage
-            response_json = response.json(exclude_unset=True,
-                                          ensure_ascii=False)
-
-            return response_json
-
-        async def completion_stream_generator() -> AsyncGenerator[str, None]:
-            previous_texts = [""] * request.n
-            previous_num_tokens = [0] * request.n
-            has_echoed = [False] * request.n
-
-            async for res in result_generator:
-                res: RequestOutput
-                for output in res.outputs:
-                    i = output.index
-                    delta_text = output.text[len(previous_texts[i]):]
+async def completion_stream_generator(
+        request: CompletionRequest,
+        result_generator: AsyncIterator[RequestOutput],
+        echo_without_generation, create_logprobs_fn, request_id, created_time,
+        model_name) -> AsyncGenerator[str, None]:
+    previous_texts = [""] * request.n
+    previous_num_tokens = [0] * request.n
+    has_echoed = [False] * request.n
+
+    async for res in result_generator:
+        # TODO: handle client disconnect for streaming
+        for output in res.outputs:
+            i = output.index
+            delta_text = output.text[len(previous_texts[i]):]

@@ -178,7 +52,7 @@ class OpenAIServingCompletion(OpenAIServing):
                     top_logprobs = res.prompt_logprobs
                     has_echoed[i] = True
                 if request.logprobs is not None:
-                    logprobs = self._create_logprobs(
+                    logprobs = create_logprobs_fn(
                         token_ids=token_ids,
                         top_logprobs=top_logprobs,
                         num_output_top_logprobs=request.logprobs,

@@ -189,16 +63,22 @@ class OpenAIServingCompletion(OpenAIServing):
                 previous_texts[i] = output.text
                 previous_num_tokens[i] = len(output.token_ids)
                 finish_reason = output.finish_reason
-                response_json = create_stream_response_json(
-                    index=i,
-                    text=delta_text,
-                    logprobs=logprobs,
-                    finish_reason=finish_reason,
-                )
+                response_json = CompletionStreamResponse(
+                    id=request_id,
+                    created=created_time,
+                    model=model_name,
+                    choices=[
+                        CompletionResponseStreamChoice(
+                            index=i,
+                            text=delta_text,
+                            logprobs=logprobs,
+                            finish_reason=finish_reason,
+                        )
+                    ]).json(exclude_unset=True, ensure_ascii=False)
                 yield f"data: {response_json}\n\n"
 
                 if output.finish_reason is not None:
-                    logprobs = (LogProbs()
-                                if request.logprobs is not None else None)
+                    logprobs = LogProbs() if request.logprobs is not None else None
                     prompt_tokens = len(res.prompt_token_ids)
                     completion_tokens = len(output.token_ids)
                     final_usage = UsageInfo(

@@ -206,28 +86,54 @@ class OpenAIServingCompletion(OpenAIServing):
                         completion_tokens=completion_tokens,
                         total_tokens=prompt_tokens + completion_tokens,
                     )
-                    response_json = create_stream_response_json(
-                        index=i,
-                        text="",
-                        logprobs=logprobs,
-                        finish_reason=output.finish_reason,
-                        usage=final_usage,
-                    )
-                    yield f"data: {response_json}\n\n"
-
-            yield "data: [DONE]\n\n"
-
-        # Streaming response
-        if stream:
-            return completion_stream_generator()
-
-        # Non-streaming response
-        final_res: RequestOutput = None
-        async for res in result_generator:
-            if await raw_request.is_disconnected():
-                # Abort the request if the client disconnects.
-                await self.engine.abort(request_id)
-                return self.create_error_response("Client disconnected")
-            final_res = res
-        assert final_res is not None
-        choices = []
-        prompt_token_ids = final_res.prompt_token_ids
+                    response_json = CompletionStreamResponse(
+                        id=request_id,
+                        created=created_time,
+                        model=model_name,
+                        choices=[
+                            CompletionResponseStreamChoice(
+                                index=i,
+                                text="",
+                                logprobs=logprobs,
+                                finish_reason=output.finish_reason,
+                            )
+                        ],
+                        usage=final_usage,
+                    ).json(exclude_unset=True, ensure_ascii=False)
+                    yield f"data: {response_json}\n\n"
+
+    yield "data: [DONE]\n\n"
+
+
+def parse_prompt_format(prompt) -> tuple[bool, list]:
+    # get the prompt, openai supports the following
+    # "a string, array of strings, array of tokens, or array of token arrays."
+    prompt_is_tokens = False
+    prompts = [prompt]  # case 1: a string
+    if isinstance(prompt, list):
+        if len(prompt) == 0:
+            raise ValueError("please provide at least one prompt")
+        elif isinstance(prompt[0], str):
+            prompt_is_tokens = False
+            prompts = prompt  # case 2: array of strings
+        elif isinstance(prompt[0], int):
+            prompt_is_tokens = True
+            prompts = [prompt]  # case 3: array of tokens
+        elif isinstance(prompt[0], list) and isinstance(prompt[0][0], int):
+            prompt_is_tokens = True
+            prompts = prompt  # case 4: array of token arrays
+        else:
+            raise ValueError(
+                "prompt must be a string, array of strings, array of tokens, or array of token arrays"
+            )
+
+    return prompt_is_tokens, prompts
+
+
+def request_output_to_completion_response(final_res: RequestOutput, request,
+                                          echo_without_generation,
+                                          create_logprobs_fn, request_id,
+                                          created_time,
+                                          model_name) -> CompletionResponse:
+    assert final_res is not None
+    choices = []
+    prompt_token_ids = final_res.prompt_token_ids

@@ -244,7 +150,7 @@ class OpenAIServingCompletion(OpenAIServing):
         else:
             token_ids = prompt_token_ids
             top_logprobs = prompt_logprobs
-            logprobs = self._create_logprobs(
+            logprobs = create_logprobs_fn(
                 token_ids=token_ids,
                 top_logprobs=top_logprobs,
                 num_output_top_logprobs=request.logprobs,

@@ -273,7 +179,8 @@ class OpenAIServingCompletion(OpenAIServing):
         completion_tokens=num_generated_tokens,
         total_tokens=num_prompt_tokens + num_generated_tokens,
     )
-        response = CompletionResponse(
+
+    return CompletionResponse(
         id=request_id,
         created=created_time,
         model=model_name,

@@ -281,9 +188,97 @@ class OpenAIServingCompletion(OpenAIServing):
         usage=usage,
     )
 
-        if request.stream:
-            # When user requests streaming but we don't stream, we still need to
-            # return a streaming response with a single event.
-            response_json = response.json(ensure_ascii=False)
+
+class OpenAIServingCompletion(OpenAIServing):
+
+    def __init__(self, engine: AsyncLLMEngine, served_model: str):
+        super().__init__(engine=engine, served_model=served_model)
+
+    async def create_completion(self, request: CompletionRequest,
+                                raw_request: Request):
+        """Completion API similar to OpenAI's API.
+
+        See https://platform.openai.com/docs/api-reference/completions/create
+        for the API specification. This API mimics the OpenAI Completion API.
+
+        NOTE: Currently we do not support the following features:
+            - suffix (the language models we currently support do not support
+              suffix)
+            - logit_bias (to be supported by vLLM engine)
+        """
+        error_check_ret = await self._check_model(request)
+        if error_check_ret is not None:
+            return error_check_ret
+
+        # OpenAI API supports echoing the prompt when max_tokens is 0.
+        echo_without_generation = request.echo and request.max_tokens == 0
+
+        # Return error for unsupported features.
+        if request.suffix is not None:
+            return self.create_error_response(
+                "suffix is not currently supported")
+
+        if request.logit_bias is not None and len(request.logit_bias) > 0:
+            return self.create_error_response(
+                "logit_bias is not currently supported")
+
+        model_name = request.model
+        request_id = f"cmpl-{random_uuid()}"
+        created_time = int(time.monotonic())
+
+        # Schedule the request and get the result generator.
+        try:
+            sampling_params = request.to_sampling_params()
+            prompt_is_tokens, prompts = parse_prompt_format(request.prompt)
+
+            if len(prompts) > 1:
+                raise ValueError(
+                    "Batching in completion API is not supported.")
+
+            prompt = prompts[0]
+
+            if prompt_is_tokens:
+                input_ids = self._validate_prompt_and_tokenize(
+                    request, prompt_ids=prompt)
+            else:
+                input_ids = self._validate_prompt_and_tokenize(request,
+                                                               prompt=prompt)
+
+            result_generator = self.engine.generate(None,
+                                                    sampling_params,
+                                                    request_id,
+                                                    prompt_token_ids=input_ids)
+        except ValueError as e:
+            return self.create_error_response(str(e))
+
+        # Similar to the OpenAI API, when n != best_of, we do not stream the
+        # results. In addition, we do not stream the results when use beam search.
+        stream = (request.stream
+                  and (request.best_of is None or request.n == request.best_of)
+                  and not request.use_beam_search)
+
+        # Streaming response
+        if stream:
+            return completion_stream_generator(request, result_generator,
+                                               echo_without_generation,
+                                               self._create_logprobs,
+                                               request_id, created_time,
+                                               model_name)
+
+        # Non-streaming response
+        final_res: RequestOutput = None
+        async for res in result_generator:
+            if await raw_request.is_disconnected():
+                # Abort the request if the client disconnects.
+                await self.engine.abort(request_id)
+                return self.create_error_response("Client disconnected")
+            final_res = res
+        response = request_output_to_completion_response(
+            final_res, request, echo_without_generation,
+            self._create_logprobs, request_id, created_time, model_name)
+
+        # When user requests streaming but we don't stream, we still need to
+        # return a streaming response with a single event.
+        if request.stream:
+            response_json = response.json(ensure_ascii=False)
 
             async def fake_stream_generator() -> AsyncGenerator[str, None]:
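To make the new prompt handling concrete, the sketch below (an illustration, not part of the commit) shows what the module-level parse_prompt_format helper returns for each of the four prompt shapes its comment mentions, assuming the function is importable as defined above:

from vllm.entrypoints.openai.serving_completion import parse_prompt_format

print(parse_prompt_format("Hello"))              # case 1: (False, ['Hello'])
print(parse_prompt_format(["Hello", "World"]))   # case 2: (False, ['Hello', 'World'])
print(parse_prompt_format([1, 2, 3]))            # case 3: (True, [[1, 2, 3]])
print(parse_prompt_format([[1, 2, 3], [4, 5]]))  # case 4: (True, [[1, 2, 3], [4, 5]])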
vllm/entrypoints/openai/serving_engine.py  (View file @ dd7e8f5f)

 import asyncio
 from http import HTTPStatus
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Union
 from vllm.logger import init_logger
 from vllm.transformers_utils.tokenizer import get_tokenizer
 from vllm.engine.async_llm_engine import AsyncLLMEngine

@@ -104,27 +104,30 @@ class OpenAIServing:
                                           err_type="NotFoundError",
                                           status_code=HTTPStatus.NOT_FOUND)
 
-    async def _check_length(
-        self,
-        request: Union[ChatCompletionRequest, CompletionRequest],
-        prompt: Optional[str] = None,
-        prompt_ids: Optional[List[int]] = None
-    ) -> Tuple[List[int], Optional[ErrorResponse]]:
-        assert (not (prompt is None and prompt_ids is None)
-                and not (prompt is not None and prompt_ids is not None)
-                ), "Either prompt or prompt_ids should be provided."
+    def _validate_prompt_and_tokenize(
+            self,
+            request: Union[ChatCompletionRequest, CompletionRequest],
+            prompt: Optional[str] = None,
+            prompt_ids: Optional[List[int]] = None) -> List[int]:
+        if not (prompt or prompt_ids):
+            raise ValueError(
+                "Either prompt or prompt_ids should be provided.")
+        if (prompt and prompt_ids):
+            raise ValueError(
+                "Only one of prompt or prompt_ids should be provided.")
+
         input_ids = prompt_ids if prompt_ids is not None else self.tokenizer(
             prompt).input_ids
         token_num = len(input_ids)
 
         if request.max_tokens is None:
             request.max_tokens = self.max_model_len - token_num
         if token_num + request.max_tokens > self.max_model_len:
-            return input_ids, self.create_error_response(
+            raise ValueError(
                 f"This model's maximum context length is {self.max_model_len} tokens. "
                 f"However, you requested {request.max_tokens + token_num} tokens "
                 f"({token_num} in the messages, "
                 f"{request.max_tokens} in the completion). "
                 f"Please reduce the length of the messages or completion.", )
         else:
-            return input_ids, None
+            return input_ids
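The substantive change in this file is the error contract: _validate_prompt_and_tokenize raises ValueError instead of returning an (input_ids, error_response) tuple, and the handlers turn the message into an error response inside their existing try/except blocks. A self-contained toy sketch of that contract, with made-up numbers and a hypothetical helper name (not code from the commit):

def validate_length(prompt_tokens: int, max_tokens: int,
                    max_model_len: int = 4096) -> int:
    # Stand-in for the length check above: raise instead of returning
    # a (value, error) pair.
    if prompt_tokens + max_tokens > max_model_len:
        raise ValueError(
            f"This model's maximum context length is {max_model_len} tokens. "
            f"However, you requested {prompt_tokens + max_tokens} tokens.")
    return max_tokens


try:
    validate_length(prompt_tokens=1000, max_tokens=3200)
except ValueError as e:
    # In the serving handlers this becomes self.create_error_response(str(e)).
    print(f"error response: {e}")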
vllm/model_executor/weight_utils.py  (View file @ dd7e8f5f)

@@ -163,7 +163,7 @@ def prepare_hf_model_weights(
             use_safetensors = True
             break
 
-    logger.info(f"Downloading model weights {allow_patterns}")
+    logger.info(f"Using model weights format {allow_patterns}")
 
     # Use file lock to prevent multiple processes from
     # downloading the same model weights at the same time.
     with get_lock(model_name_or_path, cache_dir):